tool_parser/parsers/
llama.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use serde_json::Value;
4
5use crate::{
6    errors::{ParserError, ParserResult},
7    parsers::helpers,
8    partial_json::PartialJson,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall},
11};
12
13/// Llama 3.2 format parser for tool calls
14///
15/// Handles the Llama 3.2 specific format:
16/// `<|python_tag|>{"name": "func", "parameters": {...}}`
17///
18/// Also supports plain JSON without the python_tag prefix
19pub struct LlamaParser {
20    /// Parser for handling incomplete JSON during streaming
21    partial_json: PartialJson,
22
23    /// Buffer for accumulating incomplete patterns across chunks
24    buffer: String,
25
26    /// Stores complete tool call info (name and arguments) for each tool being parsed
27    prev_tool_call_arr: Vec<Value>,
28
29    /// Index of currently streaming tool call (-1 means no active tool)
30    current_tool_id: i32,
31
32    /// Flag for whether current tool's name has been sent to client
33    current_tool_name_sent: bool,
34
35    /// Tracks raw JSON string content streamed to client for each tool's arguments
36    streamed_args_for_tool: Vec<String>,
37
38    /// Token configuration
39    bot_token: &'static str,
40    tool_call_separator: &'static str,
41}
42
43impl LlamaParser {
44    /// Create a new Llama parser
45    pub fn new() -> Self {
46        Self {
47            partial_json: PartialJson::default(),
48            buffer: String::new(),
49            prev_tool_call_arr: Vec::new(),
50            current_tool_id: -1,
51            current_tool_name_sent: false,
52            streamed_args_for_tool: Vec::new(),
53            bot_token: "<|python_tag|>",
54            tool_call_separator: ";",
55        }
56    }
57
58    /// Extract content after python_tag token
59    fn extract_content_after_python_tag(&self, text: &str) -> Option<(String, String)> {
60        const PYTHON_TAG: &str = "<|python_tag|>";
61
62        if let Some(tag_pos) = text.find(PYTHON_TAG) {
63            let normal_text = text[..tag_pos].to_string();
64            let json_content = text[tag_pos + PYTHON_TAG.len()..].to_string();
65            Some((normal_text, json_content))
66        } else {
67            None
68        }
69    }
70
71    /// Parse a single JSON object into a ToolCall (Llama format: name + parameters)
72    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
73        // Llama format only: {"name": "function_name", "parameters": {...}}
74        let name = obj.get("name").and_then(|v| v.as_str());
75
76        if let Some(name) = name {
77            // Llama uses "parameters" key
78            let empty_obj = Value::Object(serde_json::Map::new());
79            let parameters = obj.get("parameters").unwrap_or(&empty_obj);
80
81            // Convert parameters to JSON string
82            let arguments = serde_json::to_string(parameters)
83                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
84
85            Ok(Some(ToolCall {
86                function: FunctionCall {
87                    name: name.to_string(),
88                    arguments,
89                },
90            }))
91        } else {
92            Ok(None)
93        }
94    }
95
96    /// Parse semicolon-separated JSON objects
97    fn parse_semicolon_separated(&self, content: &str) -> ParserResult<Vec<ToolCall>> {
98        let mut all_tools = Vec::new();
99
100        // Split by semicolon and parse each JSON object
101        for part in content.split(';') {
102            let trimmed = part.trim();
103            if trimmed.is_empty() {
104                continue;
105            }
106
107            // Try to parse this part as a single JSON object
108            match serde_json::from_str::<Value>(trimmed) {
109                Ok(value) => {
110                    if let Some(tool) = self.parse_single_object(&value)? {
111                        all_tools.push(tool);
112                    }
113                }
114                Err(e) => {
115                    // Skip invalid JSON parts in semicolon-separated list
116                    tracing::debug!("Failed to parse tool call: {}", e);
117                }
118            }
119        }
120
121        Ok(all_tools)
122    }
123}
124
125impl Default for LlamaParser {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131#[async_trait]
132impl ToolParser for LlamaParser {
133    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
134        // Extract normal text and JSON content
135        let (normal_text, json_content) =
136            if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
137                (normal, json)
138            } else if text.trim_start().starts_with('{') {
139                (String::new(), text.to_string())
140            } else {
141                // No JSON structure found
142                return Ok((text.to_string(), vec![]));
143            };
144
145        // Parse the JSON content (may contain semicolon-separated objects)
146        let tools = if json_content.contains(';') {
147            self.parse_semicolon_separated(&json_content)?
148        } else {
149            // Try single JSON object
150            let parsed = serde_json::from_str::<Value>(json_content.trim())
151                .map_err(|e| ParserError::ParsingFailed(e.to_string()))
152                .and_then(|v| {
153                    self.parse_single_object(&v)
154                        .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
155                });
156
157            parsed.unwrap_or_else(|e| {
158                tracing::debug!("Failed to parse tool call: {:?}", e);
159                vec![]
160            })
161        };
162
163        // If we couldn't parse any tools, return the original text
164        if tools.is_empty() {
165            return Ok((text.to_string(), vec![]));
166        }
167
168        Ok((normal_text, tools))
169    }
170
171    async fn parse_incremental(
172        &mut self,
173        chunk: &str,
174        tools: &[Tool],
175    ) -> ParserResult<StreamingParseResult> {
176        // Append new text to buffer
177        self.buffer.push_str(chunk);
178        let current_text = &self.buffer.clone();
179
180        // Check if current_text has tool_call
181        let has_tool_start = self.has_tool_markers(current_text)
182            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
183
184        if !has_tool_start {
185            // Only clear buffer if we're sure no tool call is starting
186            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
187                let normal_text = self.buffer.clone();
188                self.buffer.clear();
189
190                return Ok(StreamingParseResult {
191                    normal_text,
192                    calls: vec![],
193                });
194            } else {
195                // Might be partial bot_token, keep buffering
196                return Ok(StreamingParseResult::default());
197            }
198        }
199
200        // Build tool indices
201        let tool_indices = helpers::get_tool_indices(tools);
202
203        // Determine start index for JSON parsing
204        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
205            pos + self.bot_token.len()
206        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
207            self.tool_call_separator.len()
208        } else {
209            0
210        };
211
212        helpers::handle_json_tool_streaming(
213            current_text,
214            start_idx,
215            &mut self.partial_json,
216            &tool_indices,
217            &mut self.buffer,
218            &mut self.current_tool_id,
219            &mut self.current_tool_name_sent,
220            &mut self.streamed_args_for_tool,
221            &mut self.prev_tool_call_arr,
222        )
223    }
224
225    fn has_tool_markers(&self, text: &str) -> bool {
226        // Llama format if contains python_tag or starts with JSON object
227        text.contains("<|python_tag|>") || text.trim_start().starts_with('{')
228    }
229
230    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::types::ToolCallItem>> {
231        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
232    }
233
234    fn reset(&mut self) {
235        helpers::reset_parser_state(
236            &mut self.buffer,
237            &mut self.prev_tool_call_arr,
238            &mut self.current_tool_id,
239            &mut self.current_tool_name_sent,
240            &mut self.streamed_args_for_tool,
241        );
242    }
243}