1use std::collections::BTreeMap;
2
3use serde::de::{Error as DeError, MapAccess, Visitor};
4use serde::ser::SerializeMap;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use serde_json::{Map, Number, Value};
7
8use crate::AttachmentRef;
9
// Reserved object key that marks a JSON object as an encoded `ToolValue` envelope rather than
// ordinary user data.
const TAG_KEY: &str = "$lash_tool_value";
// Envelope tag for `ToolValue::Attachment`; the reference itself is stored under `REF_KEY`.
const ATTACHMENT_TAG: &str = "attachment";
// Envelope tag for a user object that itself contains `TAG_KEY`; its original entries are
// stored under `ENTRIES_KEY` so they are not mistaken for an envelope when decoding.
const OBJECT_TAG: &str = "object";
// Envelope field holding the serialized attachment reference.
const REF_KEY: &str = "ref";
// Envelope field holding the escaped object's original entries.
const ENTRIES_KEY: &str = "entries";
15
16#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
17pub struct ToolCallOutput {
18 pub outcome: ToolCallOutcome,
19 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub control: Option<ToolControl>,
21}
22
23impl ToolCallOutput {
24 pub fn success(value: impl Into<ToolValue>) -> Self {
25 Self {
26 outcome: ToolCallOutcome::Success(value.into()),
27 control: None,
28 }
29 }
30
31 pub fn failure(failure: ToolFailure) -> Self {
32 Self {
33 outcome: ToolCallOutcome::Failure(failure),
34 control: None,
35 }
36 }
37
38 pub fn cancelled(cancellation: ToolCancellation) -> Self {
39 Self {
40 outcome: ToolCallOutcome::Cancelled(cancellation),
41 control: None,
42 }
43 }
44
45 pub fn with_control(mut self, control: ToolControl) -> Self {
46 self.control = Some(control);
47 self
48 }
49
50 pub fn is_success(&self) -> bool {
51 matches!(self.outcome, ToolCallOutcome::Success(_))
52 }
53
54 pub fn status(&self) -> ToolCallStatus {
55 match self.outcome {
56 ToolCallOutcome::Success(_) => ToolCallStatus::Success,
57 ToolCallOutcome::Failure(_) => ToolCallStatus::Failure,
58 ToolCallOutcome::Cancelled(_) => ToolCallStatus::Cancelled,
59 }
60 }
61
62 pub fn value_for_projection(&self) -> Value {
63 match &self.outcome {
64 ToolCallOutcome::Success(value) => value.to_json_value(),
65 ToolCallOutcome::Failure(failure) => failure.to_json_value(),
66 ToolCallOutcome::Cancelled(cancellation) => cancellation.to_json_value(),
67 }
68 }
69
70 pub fn attachments(&self) -> Vec<AttachmentRef> {
71 match &self.outcome {
72 ToolCallOutcome::Success(value) => value.attachments(),
73 ToolCallOutcome::Failure(failure) => failure
74 .raw
75 .as_ref()
76 .map(ToolValue::attachments)
77 .unwrap_or_default(),
78 ToolCallOutcome::Cancelled(cancellation) => cancellation
79 .raw
80 .as_ref()
81 .map(ToolValue::attachments)
82 .unwrap_or_default(),
83 }
84 }
85}
86
87#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
88#[serde(rename_all = "snake_case")]
89pub enum ToolCallStatus {
90 Success,
91 Failure,
92 Cancelled,
93}
94
95#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
96#[serde(tag = "status", content = "payload", rename_all = "snake_case")]
97pub enum ToolCallOutcome {
98 Success(ToolValue),
99 Failure(ToolFailure),
100 Cancelled(ToolCancellation),
101}
102
103#[derive(Clone, Debug, PartialEq)]
104pub enum ToolValue {
105 Null,
106 Bool(bool),
107 Number(Number),
108 String(String),
109 Array(Vec<ToolValue>),
110 Object(BTreeMap<String, ToolValue>),
111 Attachment(AttachmentRef),
112}
113
114impl ToolValue {
115 pub fn to_json_value(&self) -> Value {
116 serde_json::to_value(self).unwrap_or(Value::Null)
117 }
118
119 pub fn from_json_value(value: Value) -> serde_json::Result<Self> {
120 serde_json::from_value(value)
121 }
122
123 pub fn attachments(&self) -> Vec<AttachmentRef> {
124 let mut attachments = Vec::new();
125 self.collect_attachments(&mut attachments);
126 attachments
127 }
128
129 pub fn model_parts(&self) -> Vec<ModelToolReturnPart> {
130 let mut parts = Vec::new();
131 match self {
132 Self::String(text) => push_text_part(&mut parts, text.clone()),
133 Self::Attachment(reference) => {
134 parts.push(ModelToolReturnPart::Attachment(reference.clone()))
135 }
136 Self::Null | Self::Bool(_) | Self::Number(_) | Self::Array(_) | Self::Object(_) => {
137 self.push_compact_model_parts(&mut parts);
138 }
139 }
140 parts
141 }
142
143 fn collect_attachments(&self, attachments: &mut Vec<AttachmentRef>) {
144 match self {
145 Self::Attachment(reference) => attachments.push(reference.clone()),
146 Self::Array(values) => {
147 for value in values {
148 value.collect_attachments(attachments);
149 }
150 }
151 Self::Object(entries) => {
152 for value in entries.values() {
153 value.collect_attachments(attachments);
154 }
155 }
156 Self::Null | Self::Bool(_) | Self::Number(_) | Self::String(_) => {}
157 }
158 }
159
160 fn push_compact_model_parts(&self, parts: &mut Vec<ModelToolReturnPart>) {
161 match self {
162 Self::Null => push_text_part(parts, "null"),
163 Self::Bool(value) => push_text_part(parts, value.to_string()),
164 Self::Number(value) => push_text_part(parts, value.to_string()),
165 Self::String(value) => push_text_part(
166 parts,
167 serde_json::to_string(value).unwrap_or_else(|_| "\"\"".into()),
168 ),
169 Self::Attachment(reference) => {
170 parts.push(ModelToolReturnPart::Attachment(reference.clone()))
171 }
172 Self::Array(values) => {
173 push_text_part(parts, "[");
174 for (index, value) in values.iter().enumerate() {
175 if index > 0 {
176 push_text_part(parts, ",");
177 }
178 value.push_compact_model_parts(parts);
179 }
180 push_text_part(parts, "]");
181 }
182 Self::Object(entries) => {
183 push_text_part(parts, "{");
184 for (index, (key, value)) in entries.iter().enumerate() {
185 if index > 0 {
186 push_text_part(parts, ",");
187 }
188 push_text_part(
189 parts,
190 serde_json::to_string(key).unwrap_or_else(|_| "\"\"".into()),
191 );
192 push_text_part(parts, ":");
193 value.push_compact_model_parts(parts);
194 }
195 push_text_part(parts, "}");
196 }
197 }
198 }
199}
200
201impl From<Value> for ToolValue {
202 fn from(value: Value) -> Self {
203 match value {
204 Value::Null => Self::Null,
205 Value::Bool(value) => Self::Bool(value),
206 Value::Number(value) => Self::Number(value),
207 Value::String(value) => Self::String(value),
208 Value::Array(values) => Self::Array(values.into_iter().map(Self::from).collect()),
209 Value::Object(values) => Self::Object(
210 values
211 .into_iter()
212 .map(|(key, value)| (key, Self::from(value)))
213 .collect(),
214 ),
215 }
216 }
217}
218
219impl From<&str> for ToolValue {
220 fn from(value: &str) -> Self {
221 Self::String(value.to_string())
222 }
223}
224
225impl From<String> for ToolValue {
226 fn from(value: String) -> Self {
227 Self::String(value)
228 }
229}
230
231impl Serialize for ToolValue {
232 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233 where
234 S: Serializer,
235 {
236 match self {
237 Self::Null => serializer.serialize_none(),
238 Self::Bool(value) => serializer.serialize_bool(*value),
239 Self::Number(value) => value.serialize(serializer),
240 Self::String(value) => serializer.serialize_str(value),
241 Self::Array(values) => values.serialize(serializer),
242 Self::Attachment(reference) => {
243 let mut map = serializer.serialize_map(Some(2))?;
244 map.serialize_entry(TAG_KEY, ATTACHMENT_TAG)?;
245 map.serialize_entry(REF_KEY, reference)?;
246 map.end()
247 }
248 Self::Object(entries) => {
249 if entries.contains_key(TAG_KEY) {
250 let mut map = serializer.serialize_map(Some(2))?;
251 map.serialize_entry(TAG_KEY, OBJECT_TAG)?;
252 map.serialize_entry(ENTRIES_KEY, entries)?;
253 map.end()
254 } else {
255 entries.serialize(serializer)
256 }
257 }
258 }
259 }
260}
261
262impl<'de> Deserialize<'de> for ToolValue {
263 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
264 where
265 D: Deserializer<'de>,
266 {
267 struct ToolValueVisitor;
268
269 impl<'de> Visitor<'de> for ToolValueVisitor {
270 type Value = ToolValue;
271
272 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 formatter.write_str("a Lash tool value")
274 }
275
276 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> {
277 Ok(ToolValue::Bool(value))
278 }
279
280 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> {
281 Ok(ToolValue::Number(Number::from(value)))
282 }
283
284 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> {
285 Ok(ToolValue::Number(Number::from(value)))
286 }
287
288 fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
289 where
290 E: DeError,
291 {
292 Number::from_f64(value)
293 .map(ToolValue::Number)
294 .ok_or_else(|| E::custom("non-finite number is not a valid tool value"))
295 }
296
297 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> {
298 Ok(ToolValue::String(value.to_string()))
299 }
300
301 fn visit_string<E>(self, value: String) -> Result<Self::Value, E> {
302 Ok(ToolValue::String(value))
303 }
304
305 fn visit_none<E>(self) -> Result<Self::Value, E> {
306 Ok(ToolValue::Null)
307 }
308
309 fn visit_unit<E>(self) -> Result<Self::Value, E> {
310 Ok(ToolValue::Null)
311 }
312
313 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
314 where
315 A: serde::de::SeqAccess<'de>,
316 {
317 let mut values = Vec::new();
318 while let Some(value) = seq.next_element()? {
319 values.push(value);
320 }
321 Ok(ToolValue::Array(values))
322 }
323
324 fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
325 where
326 A: MapAccess<'de>,
327 {
328 let mut map = Map::new();
329 while let Some((key, value)) = access.next_entry::<String, Value>()? {
330 map.insert(key, value);
331 }
332 decode_object(map).map_err(A::Error::custom)
333 }
334 }
335
336 deserializer.deserialize_any(ToolValueVisitor)
337 }
338}
339
340fn decode_object(mut map: Map<String, Value>) -> serde_json::Result<ToolValue> {
341 let Some(tag) = map.get(TAG_KEY) else {
342 return Ok(ToolValue::Object(
343 map.into_iter()
344 .map(|(key, value)| Ok((key, ToolValue::from_json_value(value)?)))
345 .collect::<serde_json::Result<_>>()?,
346 ));
347 };
348 let tag = tag
349 .as_str()
350 .ok_or_else(|| serde_json::Error::custom("reserved tool value tag must be a string"))?;
351 match tag {
352 ATTACHMENT_TAG => {
353 if map.len() != 2 || !map.contains_key(REF_KEY) {
354 return Err(serde_json::Error::custom("malformed attachment tool value"));
355 }
356 let reference = serde_json::from_value(
357 map.remove(REF_KEY)
358 .ok_or_else(|| serde_json::Error::custom("missing attachment ref"))?,
359 )?;
360 Ok(ToolValue::Attachment(reference))
361 }
362 OBJECT_TAG => {
363 if map.len() != 2 || !map.contains_key(ENTRIES_KEY) {
364 return Err(serde_json::Error::custom(
365 "malformed escaped object tool value",
366 ));
367 }
368 serde_json::from_value(
369 map.remove(ENTRIES_KEY)
370 .ok_or_else(|| serde_json::Error::custom("missing escaped object entries"))?,
371 )
372 .map(ToolValue::Object)
373 }
374 other => Err(serde_json::Error::custom(format!(
375 "unknown reserved tool value tag `{other}`"
376 ))),
377 }
378}
379
380#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
381pub struct ToolFailure {
382 pub class: ToolFailureClass,
383 pub code: String,
384 pub message: String,
385 pub source: ToolFailureSource,
386 pub retry: ToolRetryDisposition,
387 #[serde(default, skip_serializing_if = "Option::is_none")]
388 pub raw: Option<ToolValue>,
389}
390
391impl ToolFailure {
392 pub fn new(
393 class: ToolFailureClass,
394 code: impl Into<String>,
395 message: impl Into<String>,
396 ) -> Self {
397 Self {
398 class,
399 code: code.into(),
400 message: message.into(),
401 source: ToolFailureSource::Runtime,
402 retry: ToolRetryDisposition::Never,
403 raw: None,
404 }
405 }
406
407 pub fn runtime(
408 class: ToolFailureClass,
409 code: impl Into<String>,
410 message: impl Into<String>,
411 ) -> Self {
412 Self::new(class, code, message)
413 }
414
415 pub fn tool(
416 class: ToolFailureClass,
417 code: impl Into<String>,
418 message: impl Into<String>,
419 ) -> Self {
420 Self {
421 source: ToolFailureSource::Tool,
422 ..Self::new(class, code, message)
423 }
424 }
425
426 pub fn safe_retry(
427 class: ToolFailureClass,
428 code: impl Into<String>,
429 message: impl Into<String>,
430 after_ms: Option<u64>,
431 ) -> Self {
432 let mut failure = Self::tool(class, code, message);
433 failure.retry = ToolRetryDisposition::Safe { after_ms };
434 failure
435 }
436
437 pub fn with_retry(mut self, retry: ToolRetryDisposition) -> Self {
438 self.retry = retry;
439 self
440 }
441
442 pub fn to_json_value(&self) -> Value {
443 serde_json::to_value(self).unwrap_or_else(|_| Value::String(self.message.clone()))
444 }
445}
446
447#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
448#[serde(rename_all = "snake_case")]
449pub enum ToolFailureClass {
450 InvalidRequest,
451 Unavailable,
452 PermissionDenied,
453 Timeout,
454 Execution,
455 External,
456 ResourceLimit,
457 Internal,
458}
459
460#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
461#[serde(rename_all = "snake_case")]
462pub enum ToolFailureSource {
463 Runtime,
464 Tool,
465 Plugin,
466 Policy,
467 Cancellation,
468}
469
470#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
471#[serde(tag = "type", rename_all = "snake_case")]
472pub enum ToolRetryDisposition {
473 Never,
474 Safe {
475 #[serde(default, skip_serializing_if = "Option::is_none")]
476 after_ms: Option<u64>,
477 },
478 Exhausted {
479 attempts: u32,
480 },
481}
482
483#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
484pub struct ToolCancellation {
485 pub message: String,
486 pub source: ToolFailureSource,
487 #[serde(default, skip_serializing_if = "Option::is_none")]
488 pub raw: Option<ToolValue>,
489}
490
491impl ToolCancellation {
492 pub fn runtime(message: impl Into<String>) -> Self {
493 Self {
494 message: message.into(),
495 source: ToolFailureSource::Cancellation,
496 raw: None,
497 }
498 }
499
500 pub fn to_json_value(&self) -> Value {
501 serde_json::to_value(self).unwrap_or_else(|_| Value::String(self.message.clone()))
502 }
503}
504
505#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
506#[serde(tag = "type", rename_all = "snake_case")]
507pub enum ToolControl {
508 Handoff { session_id: String },
509 Finish { value: ToolValue },
510 Fail { failure: ToolFailure },
511}
512
513#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
514pub struct ModelToolReturn {
515 pub call_id: String,
516 pub tool_name: String,
517 pub parts: Vec<ModelToolReturnPart>,
518}
519
520impl ModelToolReturn {
521 pub fn from_output(call_id: String, tool_name: String, output: &ToolCallOutput) -> Self {
522 let parts = model_parts_from_tool_output(output);
523 Self {
524 call_id,
525 tool_name,
526 parts,
527 }
528 }
529
530 pub fn text(call_id: String, tool_name: String, content: impl Into<String>) -> Self {
531 Self {
532 call_id,
533 tool_name,
534 parts: vec![ModelToolReturnPart::Text(content.into())],
535 }
536 }
537}
538
539#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
540#[serde(tag = "type", rename_all = "snake_case")]
541pub enum ModelToolReturnPart {
542 Text(String),
543 Attachment(AttachmentRef),
544}
545
546pub fn model_parts_from_tool_output(output: &ToolCallOutput) -> Vec<ModelToolReturnPart> {
547 match &output.outcome {
548 ToolCallOutcome::Success(value) => value.model_parts(),
549 ToolCallOutcome::Failure(failure) => {
550 let mut parts = vec![ModelToolReturnPart::Text(format_failure_message(failure))];
551 if let Some(raw) = &failure.raw {
552 parts.extend(
553 raw.attachments()
554 .into_iter()
555 .map(ModelToolReturnPart::Attachment),
556 );
557 }
558 parts
559 }
560 ToolCallOutcome::Cancelled(cancellation) => {
561 let mut parts = vec![ModelToolReturnPart::Text(format_cancellation_message(
562 cancellation,
563 ))];
564 if let Some(raw) = &cancellation.raw {
565 parts.extend(
566 raw.attachments()
567 .into_iter()
568 .map(ModelToolReturnPart::Attachment),
569 );
570 }
571 parts
572 }
573 }
574}
575
576fn push_text_part(parts: &mut Vec<ModelToolReturnPart>, text: impl Into<String>) {
577 let text = text.into();
578 if text.is_empty() {
579 return;
580 }
581 if let Some(ModelToolReturnPart::Text(existing)) = parts.last_mut() {
582 existing.push_str(&text);
583 } else {
584 parts.push(ModelToolReturnPart::Text(text));
585 }
586}
587
588fn format_failure_message(failure: &ToolFailure) -> String {
589 if failure.message.is_empty() {
590 "[Tool execution failed]".to_string()
591 } else {
592 format!("[Tool execution failed]\n{}", failure.message)
593 }
594}
595
596fn format_cancellation_message(cancellation: &ToolCancellation) -> String {
597 if cancellation.message.is_empty() {
598 "[Tool execution cancelled]".to_string()
599 } else {
600 format!("[Tool execution cancelled]\n{}", cancellation.message)
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AttachmentId, AttachmentMeta, ImageMediaType, MediaType};

    /// Builds a tiny PNG attachment reference with the given id, for use in assertions.
    fn image_ref(id: &str) -> AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            3,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    #[test]
    fn tool_value_serializes_nested_attachments() {
        let value = ToolValue::Array(vec![ToolValue::Attachment(image_ref("img"))]);

        let json = serde_json::to_value(&value).unwrap();

        assert_eq!(json[0][TAG_KEY], ATTACHMENT_TAG);
        assert_eq!(json[0][REF_KEY]["id"], "img");
        assert_eq!(serde_json::from_value::<ToolValue>(json).unwrap(), value);
    }

    #[test]
    fn tool_value_escapes_user_reserved_key() {
        let value = ToolValue::Object(BTreeMap::from([(
            TAG_KEY.to_string(),
            ToolValue::String("user".into()),
        )]));

        let json = serde_json::to_value(&value).unwrap();

        assert_eq!(json[TAG_KEY], OBJECT_TAG);
        assert!(json[ENTRIES_KEY].is_object());
        assert_eq!(serde_json::from_value::<ToolValue>(json).unwrap(), value);
    }

    #[test]
    fn tool_value_escaped_object_with_nested_attachment_round_trips() {
        let value = ToolValue::Object(BTreeMap::from([(
            TAG_KEY.to_string(),
            ToolValue::Attachment(image_ref("img")),
        )]));

        let json = serde_json::to_value(&value).unwrap();

        assert_eq!(json[TAG_KEY], OBJECT_TAG);
        assert_eq!(json[ENTRIES_KEY][TAG_KEY][TAG_KEY], ATTACHMENT_TAG);
        assert_eq!(serde_json::from_value::<ToolValue>(json).unwrap(), value);
        assert_eq!(value.attachments(), vec![image_ref("img")]);
    }

    #[test]
    fn tool_value_plain_json_round_trips() {
        let json = serde_json::json!({ "a": [1, null, true], "b": { "c": "d" } });

        let value = ToolValue::from_json_value(json.clone()).unwrap();

        assert_eq!(value, ToolValue::from(json.clone()));
        assert_eq!(value.to_json_value(), json);
    }

    #[test]
    fn tool_value_rejects_malformed_reserved_object() {
        let json = serde_json::json!({ TAG_KEY: ATTACHMENT_TAG, "extra": true });

        assert!(serde_json::from_value::<ToolValue>(json).is_err());
    }

    #[test]
    fn tool_value_rejects_unknown_reserved_tag() {
        let json = serde_json::json!({ TAG_KEY: "bogus" });

        assert!(serde_json::from_value::<ToolValue>(json).is_err());
    }

    #[test]
    fn tool_value_rejects_non_string_reserved_tag() {
        let json = serde_json::json!({ TAG_KEY: 1 });

        assert!(serde_json::from_value::<ToolValue>(json).is_err());
    }

    #[test]
    fn tool_value_model_parts_preserve_attachment_position() {
        let value = ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(image_ref("img")),
            ToolValue::String("after".into()),
        ]);

        assert_eq!(
            value.model_parts(),
            vec![
                ModelToolReturnPart::Text("[\"before\",".into()),
                ModelToolReturnPart::Attachment(image_ref("img")),
                ModelToolReturnPart::Text(",\"after\"]".into()),
            ]
        );
    }

    #[test]
    fn tool_value_model_parts_leave_top_level_string_unquoted() {
        assert_eq!(
            ToolValue::from("plain").model_parts(),
            vec![ModelToolReturnPart::Text("plain".into())]
        );
        assert_eq!(
            ToolValue::Null.model_parts(),
            vec![ModelToolReturnPart::Text("null".into())]
        );
    }

    #[test]
    fn tool_value_model_parts_render_objects_compactly_in_key_order() {
        let value = ToolValue::Object(BTreeMap::from([
            ("b".to_string(), ToolValue::Bool(true)),
            ("a".to_string(), ToolValue::String("x".into())),
        ]));

        assert_eq!(
            value.model_parts(),
            vec![ModelToolReturnPart::Text("{\"a\":\"x\",\"b\":true}".into())]
        );
    }

    #[test]
    fn tool_output_failure_projects_raw_attachments_after_failure_text() {
        let attachment = image_ref("img");
        let output = ToolCallOutput::failure(ToolFailure {
            class: ToolFailureClass::Execution,
            code: "boom".into(),
            message: "boom".into(),
            source: ToolFailureSource::Tool,
            retry: ToolRetryDisposition::Never,
            raw: Some(ToolValue::Object(BTreeMap::from([(
                "image".into(),
                ToolValue::Attachment(attachment.clone()),
            )]))),
        });

        assert_eq!(
            model_parts_from_tool_output(&output),
            vec![
                ModelToolReturnPart::Text("[Tool execution failed]\nboom".into()),
                ModelToolReturnPart::Attachment(attachment),
            ]
        );
    }

    #[test]
    fn tool_output_cancellation_projects_header_and_optional_message() {
        let stopped = ToolCallOutput::cancelled(ToolCancellation::runtime("stopped"));
        let silent = ToolCallOutput::cancelled(ToolCancellation::runtime(""));

        assert_eq!(
            model_parts_from_tool_output(&stopped),
            vec![ModelToolReturnPart::Text(
                "[Tool execution cancelled]\nstopped".into()
            )]
        );
        assert_eq!(
            model_parts_from_tool_output(&silent),
            vec![ModelToolReturnPart::Text("[Tool execution cancelled]".into())]
        );
    }

    #[test]
    fn tool_output_status_distinguishes_cancelled_from_failure() {
        let failure = ToolCallOutput::failure(ToolFailure::tool(
            ToolFailureClass::Execution,
            "boom",
            "boom",
        ));
        let cancelled = ToolCallOutput::cancelled(ToolCancellation::runtime("stopped"));

        assert_eq!(failure.status(), ToolCallStatus::Failure);
        assert_eq!(cancelled.status(), ToolCallStatus::Cancelled);
        assert!(!cancelled.is_success());
    }

    #[test]
    fn tool_failure_safe_retry_is_tool_sourced() {
        let failure = ToolFailure::safe_retry(
            ToolFailureClass::External,
            "rate_limited",
            "slow down",
            Some(250),
        );

        assert_eq!(failure.source, ToolFailureSource::Tool);
        assert_eq!(
            failure.retry,
            ToolRetryDisposition::Safe {
                after_ms: Some(250)
            }
        );
        assert!(failure.raw.is_none());
    }

    #[test]
    fn tool_output_with_control_round_trips_through_json() {
        let output = ToolCallOutput::success("ok").with_control(ToolControl::Finish {
            value: ToolValue::from("done"),
        });

        let json = serde_json::to_value(&output).unwrap();

        assert_eq!(json["outcome"]["status"], "success");
        assert_eq!(json["control"]["type"], "finish");
        assert_eq!(
            serde_json::from_value::<ToolCallOutput>(json).unwrap(),
            output
        );
    }
}