1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::utils::uuid7;
10
11#[cfg(feature = "specta")]
12use specta::Type;
13
/// Marker trait with no methods; [`ToolMessage`] implements it in this module.
///
/// NOTE(review): presumably used to tag types that may be returned as tool
/// output — confirm against the trait's consumers.
pub trait ToolOutputMixin {}
20
/// A parsed tool invocation: the tool's name, its JSON arguments, and an identifier.
///
/// Fields are private; read them through [`ToolCall::id`], [`ToolCall::name`] and
/// [`ToolCall::args`]. The serde derive writes the fields in declaration order.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of the call; [`ToolCall::new`] fills it from `uuid7(None)`.
    id: String,
    /// Name of the tool to invoke.
    name: String,
    /// Tool arguments as a JSON value (this type itself does not require an object).
    args: serde_json::Value,
}
35
36impl ToolCall {
37 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
39 Self {
40 id: uuid7(None).to_string(),
41 name: name.into(),
42 args,
43 }
44 }
45
46 pub fn with_id(
48 id: impl Into<String>,
49 name: impl Into<String>,
50 args: serde_json::Value,
51 ) -> Self {
52 Self {
53 id: id.into(),
54 name: name.into(),
55 args,
56 }
57 }
58
59 pub fn id(&self) -> &str {
61 &self.id
62 }
63
64 pub fn name(&self) -> &str {
66 &self.name
67 }
68
69 pub fn args(&self) -> &serde_json::Value {
71 &self.args
72 }
73}
74
/// One streamed fragment of a tool call; every field is optional.
///
/// Fields that are `None` are omitted when serializing.
/// NOTE(review): presumably fragments sharing an `index` are meant to be
/// concatenated into a complete call — confirm against the consumer.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallChunk {
    /// Tool name, if this fragment carries it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// A fragment of the argument text (kept as a string, not parsed JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<String>,
    /// Tool call identifier, if this fragment carries it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Value of the source payload's `index` field, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<i32>,
}
95
96impl ToolCallChunk {
97 pub fn new(
99 name: Option<String>,
100 args: Option<String>,
101 id: Option<String>,
102 index: Option<i32>,
103 ) -> Self {
104 Self {
105 name,
106 args,
107 id,
108 index,
109 }
110 }
111}
112
/// A tool call that could not be turned into a [`ToolCall`], kept with its raw data.
///
/// `default_tool_parser` produces these when `function.arguments` does not parse
/// as a JSON object. Fields that are `None` are omitted when serializing.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InvalidToolCall {
    /// Tool name, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The raw, unparsed argument text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<String>,
    /// Call identifier, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Optional description of the failure; `default_tool_parser` leaves it `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
133
134impl InvalidToolCall {
135 pub fn new(
137 name: Option<String>,
138 args: Option<String>,
139 id: Option<String>,
140 error: Option<String>,
141 ) -> Self {
142 Self {
143 name,
144 args,
145 id,
146 error,
147 }
148 }
149}
150
/// The output of a tool execution, associated with a tool call via `tool_call_id`.
///
/// Fields are private; use the constructors and accessor methods. When
/// deserializing, a missing `status` becomes [`ToolStatus::Success`] and missing
/// metadata maps become empty.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolMessage {
    /// Text content of the tool result.
    content: String,
    /// Identifier of the tool call this message is associated with.
    tool_call_id: String,
    /// Message identifier. Unlike the other optional fields it has no
    /// `skip_serializing_if`, so `None` serializes as `null`.
    id: Option<String>,
    /// Optional tool name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Outcome of the tool run; filled by `default_status` when absent on input.
    #[serde(default = "default_status")]
    status: ToolStatus,
    /// Optional extra JSON payload; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    artifact: Option<serde_json::Value>,
    /// Free-form extra fields; an empty map when absent on input.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
    /// Free-form response metadata; an empty map when absent on input.
    #[serde(default)]
    response_metadata: HashMap<String, serde_json::Value>,
}
186
187fn default_status() -> ToolStatus {
188 ToolStatus::Success
189}
190
191#[cfg_attr(feature = "specta", derive(Type))]
193#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
194#[serde(rename_all = "lowercase")]
195pub enum ToolStatus {
196 #[default]
197 Success,
198 Error,
199}
200
201impl ToolMessage {
202 pub fn new(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
204 Self {
205 content: content.into(),
206 tool_call_id: tool_call_id.into(),
207 id: Some(uuid7(None).to_string()),
208 name: None,
209 status: ToolStatus::Success,
210 artifact: None,
211 additional_kwargs: HashMap::new(),
212 response_metadata: HashMap::new(),
213 }
214 }
215
216 pub fn with_id(
220 id: impl Into<String>,
221 content: impl Into<String>,
222 tool_call_id: impl Into<String>,
223 ) -> Self {
224 Self {
225 content: content.into(),
226 tool_call_id: tool_call_id.into(),
227 id: Some(id.into()),
228 name: None,
229 status: ToolStatus::Success,
230 artifact: None,
231 additional_kwargs: HashMap::new(),
232 response_metadata: HashMap::new(),
233 }
234 }
235
236 pub fn error(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
238 Self {
239 content: content.into(),
240 tool_call_id: tool_call_id.into(),
241 id: Some(uuid7(None).to_string()),
242 name: None,
243 status: ToolStatus::Error,
244 artifact: None,
245 additional_kwargs: HashMap::new(),
246 response_metadata: HashMap::new(),
247 }
248 }
249
250 pub fn with_artifact(
252 content: impl Into<String>,
253 tool_call_id: impl Into<String>,
254 artifact: serde_json::Value,
255 ) -> Self {
256 Self {
257 content: content.into(),
258 tool_call_id: tool_call_id.into(),
259 id: Some(uuid7(None).to_string()),
260 name: None,
261 status: ToolStatus::Success,
262 artifact: Some(artifact),
263 additional_kwargs: HashMap::new(),
264 response_metadata: HashMap::new(),
265 }
266 }
267
268 pub fn with_name(mut self, name: impl Into<String>) -> Self {
270 self.name = Some(name.into());
271 self
272 }
273
274 pub fn content(&self) -> &str {
276 &self.content
277 }
278
279 pub fn tool_call_id(&self) -> &str {
281 &self.tool_call_id
282 }
283
284 pub fn id(&self) -> Option<&str> {
286 self.id.as_deref()
287 }
288
289 pub fn name(&self) -> Option<&str> {
291 self.name.as_deref()
292 }
293
294 pub fn status(&self) -> &ToolStatus {
296 &self.status
297 }
298
299 pub fn artifact(&self) -> Option<&serde_json::Value> {
301 self.artifact.as_ref()
302 }
303
304 pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
306 &self.additional_kwargs
307 }
308
309 pub fn response_metadata(&self) -> &HashMap<String, serde_json::Value> {
311 &self.response_metadata
312 }
313}
314
// `ToolMessage` opts into the `ToolOutputMixin` marker trait (it has no methods to implement).
impl ToolOutputMixin for ToolMessage {}
316
/// A streamed fragment of a [`ToolMessage`].
///
/// Chunks are combined with [`ToolMessageChunk::concat`] (or `+`) and converted
/// with [`ToolMessageChunk::to_message`]. The fields and serde attributes are the
/// same as on `ToolMessage`.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolMessageChunk {
    /// Text content carried by this fragment.
    content: String,
    /// Identifier of the tool call this fragment is associated with.
    tool_call_id: String,
    /// Message identifier; `None` serializes as `null` (no `skip_serializing_if`).
    id: Option<String>,
    /// Optional tool name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Outcome of the tool run; filled by `default_status` when absent on input.
    #[serde(default = "default_status")]
    status: ToolStatus,
    /// Optional extra JSON payload; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    artifact: Option<serde_json::Value>,
    /// Free-form extra fields; an empty map when absent on input.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
    /// Free-form response metadata; an empty map when absent on input.
    #[serde(default)]
    response_metadata: HashMap<String, serde_json::Value>,
}
345
346impl ToolMessageChunk {
347 pub fn new(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
349 Self {
350 content: content.into(),
351 tool_call_id: tool_call_id.into(),
352 id: None,
353 name: None,
354 status: ToolStatus::Success,
355 artifact: None,
356 additional_kwargs: HashMap::new(),
357 response_metadata: HashMap::new(),
358 }
359 }
360
361 pub fn content(&self) -> &str {
363 &self.content
364 }
365
366 pub fn tool_call_id(&self) -> &str {
368 &self.tool_call_id
369 }
370
371 pub fn id(&self) -> Option<&str> {
373 self.id.as_deref()
374 }
375
376 pub fn name(&self) -> Option<&str> {
378 self.name.as_deref()
379 }
380
381 pub fn status(&self) -> &ToolStatus {
383 &self.status
384 }
385
386 pub fn artifact(&self) -> Option<&serde_json::Value> {
388 self.artifact.as_ref()
389 }
390
391 pub fn concat(&self, other: &ToolMessageChunk) -> ToolMessageChunk {
393 let mut content = self.content.clone();
394 content.push_str(&other.content);
395
396 let status = if self.status == ToolStatus::Error || other.status == ToolStatus::Error {
398 ToolStatus::Error
399 } else {
400 ToolStatus::Success
401 };
402
403 ToolMessageChunk {
404 content,
405 tool_call_id: self.tool_call_id.clone(),
406 id: self.id.clone().or_else(|| other.id.clone()),
407 name: self.name.clone().or_else(|| other.name.clone()),
408 status,
409 artifact: self.artifact.clone().or_else(|| other.artifact.clone()),
410 additional_kwargs: self.additional_kwargs.clone(),
411 response_metadata: self.response_metadata.clone(),
412 }
413 }
414
415 pub fn to_message(&self) -> ToolMessage {
417 ToolMessage {
418 content: self.content.clone(),
419 tool_call_id: self.tool_call_id.clone(),
420 id: self.id.clone(),
421 name: self.name.clone(),
422 status: self.status.clone(),
423 artifact: self.artifact.clone(),
424 additional_kwargs: self.additional_kwargs.clone(),
425 response_metadata: self.response_metadata.clone(),
426 }
427 }
428}
429
430impl std::ops::Add for ToolMessageChunk {
431 type Output = ToolMessageChunk;
432
433 fn add(self, other: ToolMessageChunk) -> ToolMessageChunk {
434 self.concat(&other)
435 }
436}
437
438pub fn tool_call(name: impl Into<String>, args: serde_json::Value, id: Option<String>) -> ToolCall {
442 match id {
443 Some(id) => ToolCall::with_id(id, name, args),
444 None => ToolCall::new(name, args),
445 }
446}
447
448pub fn tool_call_chunk(
452 name: Option<String>,
453 args: Option<String>,
454 id: Option<String>,
455 index: Option<i32>,
456) -> ToolCallChunk {
457 ToolCallChunk::new(name, args, id, index)
458}
459
460pub fn invalid_tool_call(
464 name: Option<String>,
465 args: Option<String>,
466 id: Option<String>,
467 error: Option<String>,
468) -> InvalidToolCall {
469 InvalidToolCall::new(name, args, id, error)
470}
471
472pub fn default_tool_parser(
476 raw_tool_calls: &[serde_json::Value],
477) -> (Vec<ToolCall>, Vec<InvalidToolCall>) {
478 let mut tool_calls = Vec::new();
479 let mut invalid_tool_calls = Vec::new();
480
481 for raw_tool_call in raw_tool_calls {
482 let function = match raw_tool_call.get("function") {
483 Some(f) => f,
484 None => continue,
485 };
486
487 let function_name = function
488 .get("name")
489 .and_then(|n| n.as_str())
490 .unwrap_or("")
491 .to_string();
492
493 let arguments_str = function
494 .get("arguments")
495 .and_then(|a| a.as_str())
496 .unwrap_or("{}");
497
498 let id = raw_tool_call
499 .get("id")
500 .and_then(|i| i.as_str())
501 .map(|s| s.to_string());
502
503 match serde_json::from_str::<serde_json::Value>(arguments_str) {
504 Ok(args) if args.is_object() => {
505 tool_calls.push(tool_call(function_name, args, id));
506 }
507 _ => {
508 invalid_tool_calls.push(invalid_tool_call(
509 Some(function_name),
510 Some(arguments_str.to_string()),
511 id,
512 None,
513 ));
514 }
515 }
516 }
517
518 (tool_calls, invalid_tool_calls)
519}
520
521pub fn default_tool_chunk_parser(raw_tool_calls: &[serde_json::Value]) -> Vec<ToolCallChunk> {
525 let mut chunks = Vec::new();
526
527 for raw_tool_call in raw_tool_calls {
528 let (function_name, function_args) = match raw_tool_call.get("function") {
529 Some(f) => (
530 f.get("name")
531 .and_then(|n| n.as_str())
532 .map(|s| s.to_string()),
533 f.get("arguments")
534 .and_then(|a| a.as_str())
535 .map(|s| s.to_string()),
536 ),
537 None => (None, None),
538 };
539
540 let id = raw_tool_call
541 .get("id")
542 .and_then(|i| i.as_str())
543 .map(|s| s.to_string());
544
545 let index = raw_tool_call
546 .get("index")
547 .and_then(|i| i.as_i64())
548 .map(|i| i as i32);
549
550 chunks.push(tool_call_chunk(function_name, function_args, id, index));
551 }
552
553 chunks
554}