Skip to main content

lash_sansio/
tool_output.rs

1use std::collections::BTreeMap;
2
3use serde::de::{Error as DeError, MapAccess, Visitor};
4use serde::ser::SerializeMap;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use serde_json::{Map, Number, Value};
7
8use crate::AttachmentRef;
9
/// Reserved object key marking a JSON object as an encoded [`ToolValue`] envelope rather
/// than plain user data.
const TAG_KEY: &str = "$lash_tool_value";
/// Envelope tag for a [`ToolValue::Attachment`]; the reference is stored under [`REF_KEY`].
const ATTACHMENT_TAG: &str = "attachment";
/// Envelope tag for a user object that itself contained [`TAG_KEY`] and therefore had to be
/// escaped; its original entries are stored under [`ENTRIES_KEY`].
const OBJECT_TAG: &str = "object";
/// Payload key of an attachment envelope.
const REF_KEY: &str = "ref";
/// Payload key of an escaped-object envelope.
const ENTRIES_KEY: &str = "entries";
15
16#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
17pub struct ToolCallOutput {
18    pub outcome: ToolCallOutcome,
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    pub control: Option<ToolControl>,
21}
22
23impl ToolCallOutput {
24    pub fn success(value: impl Into<ToolValue>) -> Self {
25        Self {
26            outcome: ToolCallOutcome::Success(value.into()),
27            control: None,
28        }
29    }
30
31    pub fn failure(failure: ToolFailure) -> Self {
32        Self {
33            outcome: ToolCallOutcome::Failure(failure),
34            control: None,
35        }
36    }
37
38    pub fn cancelled(cancellation: ToolCancellation) -> Self {
39        Self {
40            outcome: ToolCallOutcome::Cancelled(cancellation),
41            control: None,
42        }
43    }
44
45    pub fn with_control(mut self, control: ToolControl) -> Self {
46        self.control = Some(control);
47        self
48    }
49
50    pub fn is_success(&self) -> bool {
51        matches!(self.outcome, ToolCallOutcome::Success(_))
52    }
53
54    pub fn status(&self) -> ToolCallStatus {
55        match self.outcome {
56            ToolCallOutcome::Success(_) => ToolCallStatus::Success,
57            ToolCallOutcome::Failure(_) => ToolCallStatus::Failure,
58            ToolCallOutcome::Cancelled(_) => ToolCallStatus::Cancelled,
59        }
60    }
61
62    pub fn value_for_projection(&self) -> Value {
63        match &self.outcome {
64            ToolCallOutcome::Success(value) => value.to_json_value(),
65            ToolCallOutcome::Failure(failure) => failure.to_json_value(),
66            ToolCallOutcome::Cancelled(cancellation) => cancellation.to_json_value(),
67        }
68    }
69
70    pub fn attachments(&self) -> Vec<AttachmentRef> {
71        match &self.outcome {
72            ToolCallOutcome::Success(value) => value.attachments(),
73            ToolCallOutcome::Failure(failure) => failure
74                .raw
75                .as_ref()
76                .map(ToolValue::attachments)
77                .unwrap_or_default(),
78            ToolCallOutcome::Cancelled(cancellation) => cancellation
79                .raw
80                .as_ref()
81                .map(ToolValue::attachments)
82                .unwrap_or_default(),
83        }
84    }
85}
86
87#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
88#[serde(rename_all = "snake_case")]
89pub enum ToolCallStatus {
90    Success,
91    Failure,
92    Cancelled,
93}
94
95#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
96#[serde(tag = "status", content = "payload", rename_all = "snake_case")]
97pub enum ToolCallOutcome {
98    Success(ToolValue),
99    Failure(ToolFailure),
100    Cancelled(ToolCancellation),
101}
102
103#[derive(Clone, Debug, PartialEq)]
104pub enum ToolValue {
105    Null,
106    Bool(bool),
107    Number(Number),
108    String(String),
109    Array(Vec<ToolValue>),
110    Object(BTreeMap<String, ToolValue>),
111    Attachment(AttachmentRef),
112}
113
114impl ToolValue {
115    pub fn to_json_value(&self) -> Value {
116        serde_json::to_value(self).unwrap_or(Value::Null)
117    }
118
119    pub fn from_json_value(value: Value) -> serde_json::Result<Self> {
120        serde_json::from_value(value)
121    }
122
123    pub fn attachments(&self) -> Vec<AttachmentRef> {
124        let mut attachments = Vec::new();
125        self.collect_attachments(&mut attachments);
126        attachments
127    }
128
129    pub fn model_parts(&self) -> Vec<ModelToolReturnPart> {
130        let mut parts = Vec::new();
131        match self {
132            Self::String(text) => push_text_part(&mut parts, text.clone()),
133            Self::Attachment(reference) => {
134                parts.push(ModelToolReturnPart::Attachment(reference.clone()))
135            }
136            Self::Null | Self::Bool(_) | Self::Number(_) | Self::Array(_) | Self::Object(_) => {
137                self.push_compact_model_parts(&mut parts);
138            }
139        }
140        parts
141    }
142
143    fn collect_attachments(&self, attachments: &mut Vec<AttachmentRef>) {
144        match self {
145            Self::Attachment(reference) => attachments.push(reference.clone()),
146            Self::Array(values) => {
147                for value in values {
148                    value.collect_attachments(attachments);
149                }
150            }
151            Self::Object(entries) => {
152                for value in entries.values() {
153                    value.collect_attachments(attachments);
154                }
155            }
156            Self::Null | Self::Bool(_) | Self::Number(_) | Self::String(_) => {}
157        }
158    }
159
160    fn push_compact_model_parts(&self, parts: &mut Vec<ModelToolReturnPart>) {
161        match self {
162            Self::Null => push_text_part(parts, "null"),
163            Self::Bool(value) => push_text_part(parts, value.to_string()),
164            Self::Number(value) => push_text_part(parts, value.to_string()),
165            Self::String(value) => push_text_part(
166                parts,
167                serde_json::to_string(value).unwrap_or_else(|_| "\"\"".into()),
168            ),
169            Self::Attachment(reference) => {
170                parts.push(ModelToolReturnPart::Attachment(reference.clone()))
171            }
172            Self::Array(values) => {
173                push_text_part(parts, "[");
174                for (index, value) in values.iter().enumerate() {
175                    if index > 0 {
176                        push_text_part(parts, ",");
177                    }
178                    value.push_compact_model_parts(parts);
179                }
180                push_text_part(parts, "]");
181            }
182            Self::Object(entries) => {
183                push_text_part(parts, "{");
184                for (index, (key, value)) in entries.iter().enumerate() {
185                    if index > 0 {
186                        push_text_part(parts, ",");
187                    }
188                    push_text_part(
189                        parts,
190                        serde_json::to_string(key).unwrap_or_else(|_| "\"\"".into()),
191                    );
192                    push_text_part(parts, ":");
193                    value.push_compact_model_parts(parts);
194                }
195                push_text_part(parts, "}");
196            }
197        }
198    }
199}
200
201impl From<Value> for ToolValue {
202    fn from(value: Value) -> Self {
203        match value {
204            Value::Null => Self::Null,
205            Value::Bool(value) => Self::Bool(value),
206            Value::Number(value) => Self::Number(value),
207            Value::String(value) => Self::String(value),
208            Value::Array(values) => Self::Array(values.into_iter().map(Self::from).collect()),
209            Value::Object(values) => Self::Object(
210                values
211                    .into_iter()
212                    .map(|(key, value)| (key, Self::from(value)))
213                    .collect(),
214            ),
215        }
216    }
217}
218
219impl From<&str> for ToolValue {
220    fn from(value: &str) -> Self {
221        Self::String(value.to_string())
222    }
223}
224
225impl From<String> for ToolValue {
226    fn from(value: String) -> Self {
227        Self::String(value)
228    }
229}
230
231impl Serialize for ToolValue {
232    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233    where
234        S: Serializer,
235    {
236        match self {
237            Self::Null => serializer.serialize_none(),
238            Self::Bool(value) => serializer.serialize_bool(*value),
239            Self::Number(value) => value.serialize(serializer),
240            Self::String(value) => serializer.serialize_str(value),
241            Self::Array(values) => values.serialize(serializer),
242            Self::Attachment(reference) => {
243                let mut map = serializer.serialize_map(Some(2))?;
244                map.serialize_entry(TAG_KEY, ATTACHMENT_TAG)?;
245                map.serialize_entry(REF_KEY, reference)?;
246                map.end()
247            }
248            Self::Object(entries) => {
249                if entries.contains_key(TAG_KEY) {
250                    let mut map = serializer.serialize_map(Some(2))?;
251                    map.serialize_entry(TAG_KEY, OBJECT_TAG)?;
252                    map.serialize_entry(ENTRIES_KEY, entries)?;
253                    map.end()
254                } else {
255                    entries.serialize(serializer)
256                }
257            }
258        }
259    }
260}
261
262impl<'de> Deserialize<'de> for ToolValue {
263    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
264    where
265        D: Deserializer<'de>,
266    {
267        struct ToolValueVisitor;
268
269        impl<'de> Visitor<'de> for ToolValueVisitor {
270            type Value = ToolValue;
271
272            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273                formatter.write_str("a Lash tool value")
274            }
275
276            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> {
277                Ok(ToolValue::Bool(value))
278            }
279
280            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> {
281                Ok(ToolValue::Number(Number::from(value)))
282            }
283
284            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> {
285                Ok(ToolValue::Number(Number::from(value)))
286            }
287
288            fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
289            where
290                E: DeError,
291            {
292                Number::from_f64(value)
293                    .map(ToolValue::Number)
294                    .ok_or_else(|| E::custom("non-finite number is not a valid tool value"))
295            }
296
297            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> {
298                Ok(ToolValue::String(value.to_string()))
299            }
300
301            fn visit_string<E>(self, value: String) -> Result<Self::Value, E> {
302                Ok(ToolValue::String(value))
303            }
304
305            fn visit_none<E>(self) -> Result<Self::Value, E> {
306                Ok(ToolValue::Null)
307            }
308
309            fn visit_unit<E>(self) -> Result<Self::Value, E> {
310                Ok(ToolValue::Null)
311            }
312
313            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
314            where
315                A: serde::de::SeqAccess<'de>,
316            {
317                let mut values = Vec::new();
318                while let Some(value) = seq.next_element()? {
319                    values.push(value);
320                }
321                Ok(ToolValue::Array(values))
322            }
323
324            fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
325            where
326                A: MapAccess<'de>,
327            {
328                let mut map = Map::new();
329                while let Some((key, value)) = access.next_entry::<String, Value>()? {
330                    map.insert(key, value);
331                }
332                decode_object(map).map_err(A::Error::custom)
333            }
334        }
335
336        deserializer.deserialize_any(ToolValueVisitor)
337    }
338}
339
340fn decode_object(mut map: Map<String, Value>) -> serde_json::Result<ToolValue> {
341    let Some(tag) = map.get(TAG_KEY) else {
342        return Ok(ToolValue::Object(
343            map.into_iter()
344                .map(|(key, value)| Ok((key, ToolValue::from_json_value(value)?)))
345                .collect::<serde_json::Result<_>>()?,
346        ));
347    };
348    let tag = tag
349        .as_str()
350        .ok_or_else(|| serde_json::Error::custom("reserved tool value tag must be a string"))?;
351    match tag {
352        ATTACHMENT_TAG => {
353            if map.len() != 2 || !map.contains_key(REF_KEY) {
354                return Err(serde_json::Error::custom("malformed attachment tool value"));
355            }
356            let reference = serde_json::from_value(
357                map.remove(REF_KEY)
358                    .ok_or_else(|| serde_json::Error::custom("missing attachment ref"))?,
359            )?;
360            Ok(ToolValue::Attachment(reference))
361        }
362        OBJECT_TAG => {
363            if map.len() != 2 || !map.contains_key(ENTRIES_KEY) {
364                return Err(serde_json::Error::custom(
365                    "malformed escaped object tool value",
366                ));
367            }
368            serde_json::from_value(
369                map.remove(ENTRIES_KEY)
370                    .ok_or_else(|| serde_json::Error::custom("missing escaped object entries"))?,
371            )
372            .map(ToolValue::Object)
373        }
374        other => Err(serde_json::Error::custom(format!(
375            "unknown reserved tool value tag `{other}`"
376        ))),
377    }
378}
379
380#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
381pub struct ToolFailure {
382    pub class: ToolFailureClass,
383    pub code: String,
384    pub message: String,
385    pub source: ToolFailureSource,
386    pub retry: ToolRetryDisposition,
387    #[serde(default, skip_serializing_if = "Option::is_none")]
388    pub raw: Option<ToolValue>,
389}
390
391impl ToolFailure {
392    pub fn new(
393        class: ToolFailureClass,
394        code: impl Into<String>,
395        message: impl Into<String>,
396    ) -> Self {
397        Self {
398            class,
399            code: code.into(),
400            message: message.into(),
401            source: ToolFailureSource::Runtime,
402            retry: ToolRetryDisposition::Never,
403            raw: None,
404        }
405    }
406
407    pub fn runtime(
408        class: ToolFailureClass,
409        code: impl Into<String>,
410        message: impl Into<String>,
411    ) -> Self {
412        Self::new(class, code, message)
413    }
414
415    pub fn tool(
416        class: ToolFailureClass,
417        code: impl Into<String>,
418        message: impl Into<String>,
419    ) -> Self {
420        Self {
421            source: ToolFailureSource::Tool,
422            ..Self::new(class, code, message)
423        }
424    }
425
426    pub fn safe_retry(
427        class: ToolFailureClass,
428        code: impl Into<String>,
429        message: impl Into<String>,
430        after_ms: Option<u64>,
431    ) -> Self {
432        let mut failure = Self::tool(class, code, message);
433        failure.retry = ToolRetryDisposition::Safe { after_ms };
434        failure
435    }
436
437    pub fn with_retry(mut self, retry: ToolRetryDisposition) -> Self {
438        self.retry = retry;
439        self
440    }
441
442    pub fn to_json_value(&self) -> Value {
443        serde_json::to_value(self).unwrap_or_else(|_| Value::String(self.message.clone()))
444    }
445}
446
447#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
448#[serde(rename_all = "snake_case")]
449pub enum ToolFailureClass {
450    InvalidRequest,
451    Unavailable,
452    PermissionDenied,
453    Timeout,
454    Execution,
455    External,
456    ResourceLimit,
457    Internal,
458}
459
460#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
461#[serde(rename_all = "snake_case")]
462pub enum ToolFailureSource {
463    Runtime,
464    Tool,
465    Plugin,
466    Policy,
467    Cancellation,
468}
469
470#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
471#[serde(tag = "type", rename_all = "snake_case")]
472pub enum ToolRetryDisposition {
473    Never,
474    Safe {
475        #[serde(default, skip_serializing_if = "Option::is_none")]
476        after_ms: Option<u64>,
477    },
478    Exhausted {
479        attempts: u32,
480    },
481}
482
483#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
484pub struct ToolCancellation {
485    pub message: String,
486    pub source: ToolFailureSource,
487    #[serde(default, skip_serializing_if = "Option::is_none")]
488    pub raw: Option<ToolValue>,
489}
490
491impl ToolCancellation {
492    pub fn runtime(message: impl Into<String>) -> Self {
493        Self {
494            message: message.into(),
495            source: ToolFailureSource::Cancellation,
496            raw: None,
497        }
498    }
499
500    pub fn to_json_value(&self) -> Value {
501        serde_json::to_value(self).unwrap_or_else(|_| Value::String(self.message.clone()))
502    }
503}
504
505#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
506#[serde(tag = "type", rename_all = "snake_case")]
507pub enum ToolControl {
508    Handoff { session_id: String },
509    Finish { value: ToolValue },
510    Fail { failure: ToolFailure },
511}
512
513#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
514pub struct ModelToolReturn {
515    pub call_id: String,
516    pub tool_name: String,
517    pub parts: Vec<ModelToolReturnPart>,
518}
519
520impl ModelToolReturn {
521    pub fn from_output(call_id: String, tool_name: String, output: &ToolCallOutput) -> Self {
522        let parts = model_parts_from_tool_output(output);
523        Self {
524            call_id,
525            tool_name,
526            parts,
527        }
528    }
529
530    pub fn text(call_id: String, tool_name: String, content: impl Into<String>) -> Self {
531        Self {
532            call_id,
533            tool_name,
534            parts: vec![ModelToolReturnPart::Text(content.into())],
535        }
536    }
537}
538
539#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
540#[serde(tag = "type", rename_all = "snake_case")]
541pub enum ModelToolReturnPart {
542    Text(String),
543    Attachment(AttachmentRef),
544}
545
546pub fn model_parts_from_tool_output(output: &ToolCallOutput) -> Vec<ModelToolReturnPart> {
547    match &output.outcome {
548        ToolCallOutcome::Success(value) => value.model_parts(),
549        ToolCallOutcome::Failure(failure) => {
550            let mut parts = vec![ModelToolReturnPart::Text(format_failure_message(failure))];
551            if let Some(raw) = &failure.raw {
552                parts.extend(
553                    raw.attachments()
554                        .into_iter()
555                        .map(ModelToolReturnPart::Attachment),
556                );
557            }
558            parts
559        }
560        ToolCallOutcome::Cancelled(cancellation) => {
561            let mut parts = vec![ModelToolReturnPart::Text(format_cancellation_message(
562                cancellation,
563            ))];
564            if let Some(raw) = &cancellation.raw {
565                parts.extend(
566                    raw.attachments()
567                        .into_iter()
568                        .map(ModelToolReturnPart::Attachment),
569                );
570            }
571            parts
572        }
573    }
574}
575
576fn push_text_part(parts: &mut Vec<ModelToolReturnPart>, text: impl Into<String>) {
577    let text = text.into();
578    if text.is_empty() {
579        return;
580    }
581    if let Some(ModelToolReturnPart::Text(existing)) = parts.last_mut() {
582        existing.push_str(&text);
583    } else {
584        parts.push(ModelToolReturnPart::Text(text));
585    }
586}
587
588fn format_failure_message(failure: &ToolFailure) -> String {
589    if failure.message.is_empty() {
590        "[Tool execution failed]".to_string()
591    } else {
592        format!("[Tool execution failed]\n{}", failure.message)
593    }
594}
595
596fn format_cancellation_message(cancellation: &ToolCancellation) -> String {
597    if cancellation.message.is_empty() {
598        "[Tool execution cancelled]".to_string()
599    } else {
600        format!("[Tool execution cancelled]\n{}", cancellation.message)
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AttachmentId, AttachmentMeta, ImageMediaType, MediaType};

    /// Builds a tiny PNG attachment reference with the given id.
    fn image_ref(id: &str) -> AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            3,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    #[test]
    fn tool_value_serializes_nested_attachments() {
        let value = ToolValue::Array(vec![ToolValue::Attachment(image_ref("img"))]);

        let json = serde_json::to_value(&value).unwrap();

        assert_eq!(json[0][TAG_KEY], ATTACHMENT_TAG);
        assert_eq!(json[0][REF_KEY]["id"], "img");
        assert_eq!(serde_json::from_value::<ToolValue>(json).unwrap(), value);
    }

    #[test]
    fn tool_value_escapes_user_reserved_key() {
        let value = ToolValue::Object(BTreeMap::from([(
            TAG_KEY.to_string(),
            ToolValue::String("user".into()),
        )]));

        let json = serde_json::to_value(&value).unwrap();

        assert_eq!(json[TAG_KEY], OBJECT_TAG);
        assert!(json[ENTRIES_KEY].is_object());
        assert_eq!(serde_json::from_value::<ToolValue>(json).unwrap(), value);
    }

    #[test]
    fn tool_value_rejects_malformed_reserved_object() {
        let json = serde_json::json!({ TAG_KEY: ATTACHMENT_TAG, "extra": true });

        assert!(serde_json::from_value::<ToolValue>(json).is_err());
    }

    #[test]
    fn tool_value_rejects_unknown_or_non_string_tag() {
        assert!(ToolValue::from_json_value(serde_json::json!({ TAG_KEY: "mystery" })).is_err());
        assert!(ToolValue::from_json_value(serde_json::json!({ TAG_KEY: 1 })).is_err());
    }

    #[test]
    fn tool_value_round_trips_plain_json() {
        let json = serde_json::json!({
            "flag": true,
            "items": [1, -2, 2.5, null, "text"],
            "nested": { "key": "value" }
        });
        let value = ToolValue::from(json.clone());

        assert_eq!(value.to_json_value(), json);
        assert_eq!(ToolValue::from_json_value(json).unwrap(), value);
    }

    #[test]
    fn tool_value_model_parts_preserve_attachment_position() {
        let value = ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(image_ref("img")),
            ToolValue::String("after".into()),
        ]);

        assert_eq!(
            value.model_parts(),
            vec![
                ModelToolReturnPart::Text("[\"before\",".into()),
                ModelToolReturnPart::Attachment(image_ref("img")),
                ModelToolReturnPart::Text(",\"after\"]".into()),
            ]
        );
    }

    #[test]
    fn tool_value_model_parts_compacts_nested_objects() {
        let value = ToolValue::from(serde_json::json!({ "b": [true, null], "a": "x" }));

        assert_eq!(
            value.model_parts(),
            vec![ModelToolReturnPart::Text(
                "{\"a\":\"x\",\"b\":[true,null]}".into()
            )]
        );
    }

    #[test]
    fn tool_value_model_parts_top_level_string_is_verbatim() {
        assert_eq!(
            ToolValue::from("plain \"text\"").model_parts(),
            vec![ModelToolReturnPart::Text("plain \"text\"".into())]
        );
    }

    #[test]
    fn tool_output_failure_projects_raw_attachments_after_failure_text() {
        let attachment = image_ref("img");
        let output = ToolCallOutput::failure(ToolFailure {
            class: ToolFailureClass::Execution,
            code: "boom".into(),
            message: "boom".into(),
            source: ToolFailureSource::Tool,
            retry: ToolRetryDisposition::Never,
            raw: Some(ToolValue::Object(BTreeMap::from([(
                "image".into(),
                ToolValue::Attachment(attachment.clone()),
            )]))),
        });

        assert_eq!(
            model_parts_from_tool_output(&output),
            vec![
                ModelToolReturnPart::Text("[Tool execution failed]\nboom".into()),
                ModelToolReturnPart::Attachment(attachment),
            ]
        );
    }

    #[test]
    fn tool_output_cancelled_projects_raw_attachments_after_text() {
        let attachment = image_ref("img");
        let output = ToolCallOutput::cancelled(ToolCancellation {
            raw: Some(ToolValue::Array(vec![ToolValue::Attachment(
                attachment.clone(),
            )])),
            ..ToolCancellation::runtime("stopped")
        });

        assert_eq!(
            model_parts_from_tool_output(&output),
            vec![
                ModelToolReturnPart::Text("[Tool execution cancelled]\nstopped".into()),
                ModelToolReturnPart::Attachment(attachment.clone()),
            ]
        );
        assert_eq!(output.attachments(), vec![attachment]);
    }

    #[test]
    fn tool_output_status_distinguishes_cancelled_from_failure() {
        let failure = ToolCallOutput::failure(ToolFailure::tool(
            ToolFailureClass::Execution,
            "boom",
            "boom",
        ));
        let cancelled = ToolCallOutput::cancelled(ToolCancellation::runtime("stopped"));

        assert_eq!(failure.status(), ToolCallStatus::Failure);
        assert_eq!(cancelled.status(), ToolCallStatus::Cancelled);
        assert!(!cancelled.is_success());
        assert!(ToolCallOutput::success("ok").is_success());
    }
}