Skip to main content

lutum_protocol/
conversation.rs

1use std::{borrow::Borrow, fmt};
2
3use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned};
4use serde_json::value::RawValue;
5use thiserror::Error;
6
7use crate::transcript::CommittedTurn;
8
9#[derive(Clone, Debug, Default)]
10pub struct ModelInput {
11    items: Vec<ModelInputItem>,
12}
13
14impl ModelInput {
15    pub fn new() -> Self {
16        Self { items: Vec::new() }
17    }
18
19    pub fn from_items(items: Vec<ModelInputItem>) -> Self {
20        Self { items }
21    }
22
23    pub fn items(&self) -> &[ModelInputItem] {
24        &self.items
25    }
26
27    pub fn into_items(self) -> Vec<ModelInputItem> {
28        self.items
29    }
30
31    pub fn push(&mut self, item: ModelInputItem) {
32        self.items.push(item);
33    }
34
35    pub fn system(mut self, text: impl Into<String>) -> Self {
36        self.push(ModelInputItem::text(InputMessageRole::System, text));
37        self
38    }
39
40    pub fn developer(mut self, text: impl Into<String>) -> Self {
41        self.push(ModelInputItem::text(InputMessageRole::Developer, text));
42        self
43    }
44
45    pub fn user(mut self, text: impl Into<String>) -> Self {
46        self.push(ModelInputItem::text(InputMessageRole::User, text));
47        self
48    }
49
50    pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
51        self.push(ModelInputItem::assistant_text(text));
52        self
53    }
54
55    pub fn assistant_reasoning(mut self, text: impl Into<String>) -> Self {
56        self.push(ModelInputItem::assistant_reasoning(text));
57        self
58    }
59
60    pub fn assistant_refusal(mut self, text: impl Into<String>) -> Self {
61        self.push(ModelInputItem::assistant_refusal(text));
62        self
63    }
64
65    pub fn tool_use(mut self, tool_use: ToolUse) -> Self {
66        self.push(ModelInputItem::tool_use(tool_use));
67        self
68    }
69
70    pub fn validate(&self) -> Result<(), ModelInputValidationError> {
71        if self.items.is_empty() {
72            return Err(ModelInputValidationError::Empty);
73        }
74
75        let mut tool_uses = std::collections::BTreeSet::new();
76        for item in &self.items {
77            if let ModelInputItem::ToolUse(tool_use) = item
78                && !tool_uses.insert(tool_use.id.clone())
79            {
80                return Err(ModelInputValidationError::DuplicateToolUseId {
81                    id: tool_use.id.clone(),
82                });
83            }
84        }
85
86        Ok(())
87    }
88}
89
90impl From<Vec<ModelInputItem>> for ModelInput {
91    fn from(items: Vec<ModelInputItem>) -> Self {
92        Self::from_items(items)
93    }
94}
95
96#[derive(Clone, Debug)]
97pub enum ModelInputItem {
98    Message {
99        role: InputMessageRole,
100        content: NonEmpty<MessageContent>,
101    },
102    Assistant(AssistantInputItem),
103    ToolUse(ToolUse),
104    Turn(CommittedTurn),
105}
106
107impl ModelInputItem {
108    pub fn message(role: InputMessageRole, content: NonEmpty<MessageContent>) -> Self {
109        Self::Message { role, content }
110    }
111
112    pub fn text(role: InputMessageRole, text: impl Into<String>) -> Self {
113        Self::Message {
114            role,
115            content: NonEmpty::one(MessageContent::Text(text.into())),
116        }
117    }
118
119    pub fn assistant(item: AssistantInputItem) -> Self {
120        Self::Assistant(item)
121    }
122
123    pub fn assistant_text(text: impl Into<String>) -> Self {
124        Self::Assistant(AssistantInputItem::Text(text.into()))
125    }
126
127    pub fn assistant_reasoning(text: impl Into<String>) -> Self {
128        Self::Assistant(AssistantInputItem::Reasoning(text.into()))
129    }
130
131    pub fn assistant_refusal(text: impl Into<String>) -> Self {
132        Self::Assistant(AssistantInputItem::Refusal(text.into()))
133    }
134
135    pub fn tool_use(tool_use: ToolUse) -> Self {
136        Self::ToolUse(tool_use)
137    }
138
139    pub fn turn(committed_turn: CommittedTurn) -> Self {
140        Self::Turn(committed_turn)
141    }
142
143    pub fn tool_use_parts(
144        id: impl Into<ToolCallId>,
145        name: impl Into<ToolName>,
146        arguments: RawJson,
147        result: RawJson,
148    ) -> Self {
149        Self::ToolUse(ToolUse::new(id, name, arguments, result))
150    }
151}
152
153#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
154pub enum InputMessageRole {
155    System,
156    Developer,
157    User,
158}
159
160#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
161pub enum MessageContent {
162    Text(String),
163}
164
165#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
166/// Assistant-authored request items that can be replayed into a future model input.
167///
168/// This is intentionally narrower than [`AssistantTurnItem`]: tool calls are represented
169/// as [`ToolUse`] at the surrounding [`ModelInputItem`] level so call/result pairs stay bundled.
170pub enum AssistantInputItem {
171    Text(String),
172    Reasoning(String),
173    Refusal(String),
174}
175
176#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
177pub struct ToolUse {
178    pub id: ToolCallId,
179    pub name: ToolName,
180    pub arguments: RawJson,
181    pub result: RawJson,
182}
183
184impl ToolUse {
185    pub fn new(
186        id: impl Into<ToolCallId>,
187        name: impl Into<ToolName>,
188        arguments: RawJson,
189        result: RawJson,
190    ) -> Self {
191        Self {
192            id: id.into(),
193            name: name.into(),
194            arguments,
195            result,
196        }
197    }
198}
199
200#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
201pub struct ToolMetadata {
202    pub id: ToolCallId,
203    pub name: ToolName,
204    pub arguments: RawJson,
205}
206
207impl ToolMetadata {
208    pub fn new(id: impl Into<ToolCallId>, name: impl Into<ToolName>, arguments: RawJson) -> Self {
209        Self {
210            id: id.into(),
211            name: name.into(),
212            arguments,
213        }
214    }
215
216    pub fn into_tool_use(self, result: RawJson) -> ToolUse {
217        ToolUse::new(self.id, self.name, self.arguments, result)
218    }
219}
220
221#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
222/// Canonical assistant output for a completed turn.
223///
224/// This remains richer than [`AssistantInputItem`] because the model can emit tool calls that
225/// are not yet paired with tool results at response time.
226pub struct AssistantTurn {
227    items: NonEmpty<AssistantTurnItem>,
228}
229
230impl AssistantTurn {
231    pub fn new(items: NonEmpty<AssistantTurnItem>) -> Self {
232        Self { items }
233    }
234
235    pub fn from_items(items: Vec<AssistantTurnItem>) -> Result<Self, EmptyNonEmptyError> {
236        Ok(Self::new(NonEmpty::try_from_vec(items)?))
237    }
238
239    pub fn items(&self) -> &[AssistantTurnItem] {
240        self.items.as_slice()
241    }
242
243    pub fn items_non_empty(&self) -> &NonEmpty<AssistantTurnItem> {
244        &self.items
245    }
246
247    pub fn into_items(self) -> NonEmpty<AssistantTurnItem> {
248        self.items
249    }
250
251    pub fn text(text: impl Into<String>) -> Self {
252        Self::new(NonEmpty::one(AssistantTurnItem::Text(text.into())))
253    }
254
255    pub fn reasoning(text: impl Into<String>) -> Self {
256        Self::new(NonEmpty::one(AssistantTurnItem::Reasoning(text.into())))
257    }
258
259    pub fn refusal(text: impl Into<String>) -> Self {
260        Self::new(NonEmpty::one(AssistantTurnItem::Refusal(text.into())))
261    }
262
263    pub fn tool_call(
264        id: impl Into<ToolCallId>,
265        name: impl Into<ToolName>,
266        arguments: RawJson,
267    ) -> Self {
268        Self::new(NonEmpty::one(AssistantTurnItem::ToolCall {
269            id: id.into(),
270            name: name.into(),
271            arguments,
272        }))
273    }
274
275    pub fn assistant_text(&self) -> String {
276        let mut text = String::new();
277        for item in self.items() {
278            if let AssistantTurnItem::Text(delta) = item {
279                text.push_str(delta);
280            }
281        }
282        text
283    }
284}
285
286#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
287/// Canonical assistant output items for a completed turn.
288///
289/// Tool calls exist only on the response side. Once a call has been paired with a tool result for
290/// replay, it is represented as [`ModelInputItem::ToolUse`] instead.
291pub enum AssistantTurnItem {
292    Text(String),
293    Reasoning(String),
294    Refusal(String),
295    ToolCall {
296        id: ToolCallId,
297        name: ToolName,
298        arguments: RawJson,
299    },
300}
301
302#[derive(Clone, Debug, Eq, PartialEq)]
303pub struct NonEmpty<T>(Vec<T>);
304
305impl<T> NonEmpty<T> {
306    pub fn one(item: T) -> Self {
307        Self(vec![item])
308    }
309
310    pub fn try_from_vec(items: Vec<T>) -> Result<Self, EmptyNonEmptyError> {
311        if items.is_empty() {
312            Err(EmptyNonEmptyError)
313        } else {
314            Ok(Self(items))
315        }
316    }
317
318    pub fn as_slice(&self) -> &[T] {
319        &self.0
320    }
321
322    pub fn iter(&self) -> std::slice::Iter<'_, T> {
323        self.0.iter()
324    }
325
326    pub fn into_vec(self) -> Vec<T> {
327        self.0
328    }
329}
330
331impl<T> TryFrom<Vec<T>> for NonEmpty<T> {
332    type Error = EmptyNonEmptyError;
333
334    fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
335        Self::try_from_vec(value)
336    }
337}
338
339impl<T> From<NonEmpty<T>> for Vec<T> {
340    fn from(value: NonEmpty<T>) -> Self {
341        value.0
342    }
343}
344
345impl<T> Serialize for NonEmpty<T>
346where
347    T: Serialize,
348{
349    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
350    where
351        S: Serializer,
352    {
353        self.0.serialize(serializer)
354    }
355}
356
357impl<'de, T> Deserialize<'de> for NonEmpty<T>
358where
359    T: Deserialize<'de>,
360{
361    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
362    where
363        D: Deserializer<'de>,
364    {
365        let values = Vec::<T>::deserialize(deserializer)?;
366        Self::try_from_vec(values).map_err(serde::de::Error::custom)
367    }
368}
369
370impl<T> Borrow<[T]> for NonEmpty<T> {
371    fn borrow(&self) -> &[T] {
372        self.as_slice()
373    }
374}
375
376#[derive(Debug, Error, Clone, Copy, Eq, PartialEq)]
377#[error("non-empty collection must contain at least one element")]
378pub struct EmptyNonEmptyError;
379
380#[derive(Debug, Error, Clone, Eq, PartialEq)]
381pub enum ModelInputValidationError {
382    #[error("model input must contain at least one item")]
383    Empty,
384    #[error("duplicate tool use id `{id}` in model input")]
385    DuplicateToolUseId { id: ToolCallId },
386}
387
388#[derive(Debug, Error, Clone, Eq, PartialEq)]
389pub enum AssistantTurnInputError {
390    #[error("assistant turn references missing tool use `{id}`")]
391    MissingToolUse { id: ToolCallId },
392    #[error("assistant turn received duplicate tool use `{id}`")]
393    DuplicateToolUse { id: ToolCallId },
394    #[error("assistant turn received extra tool use `{id}`")]
395    ExtraToolUse { id: ToolCallId },
396    #[error("assistant turn tool call `{id}` expected tool name `{expected}`, got `{actual}`")]
397    MismatchedToolName {
398        id: ToolCallId,
399        expected: ToolName,
400        actual: ToolName,
401    },
402    #[error("assistant turn tool call `{id}` received mismatched arguments")]
403    MismatchedToolArguments {
404        id: ToolCallId,
405        expected: RawJson,
406        actual: RawJson,
407    },
408}
409
410#[derive(Serialize, Deserialize)]
411#[serde(transparent)]
412pub struct RawJson(Box<RawValue>);
413
414impl RawJson {
415    pub fn parse(json: impl Into<String>) -> Result<Self, serde_json::Error> {
416        RawValue::from_string(json.into()).map(Self)
417    }
418
419    pub fn from_serializable<T>(value: &T) -> Result<Self, serde_json::Error>
420    where
421        T: Serialize,
422    {
423        RawValue::from_string(serde_json::to_string(value)?).map(Self)
424    }
425
426    pub fn get(&self) -> &str {
427        self.0.get()
428    }
429
430    pub fn deserialize<T>(&self) -> Result<T, serde_json::Error>
431    where
432        T: DeserializeOwned,
433    {
434        serde_json::from_str(self.get())
435    }
436}
437
438impl Clone for RawJson {
439    fn clone(&self) -> Self {
440        Self(self.0.clone())
441    }
442}
443
444impl fmt::Debug for RawJson {
445    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446        f.debug_tuple("RawJson").field(&self.get()).finish()
447    }
448}
449
450impl PartialEq for RawJson {
451    fn eq(&self, other: &Self) -> bool {
452        self.get() == other.get()
453    }
454}
455
456impl Eq for RawJson {}
457
458impl PartialOrd for RawJson {
459    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
460        Some(self.cmp(other))
461    }
462}
463
464impl Ord for RawJson {
465    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
466        self.get().cmp(other.get())
467    }
468}
469
470impl std::hash::Hash for RawJson {
471    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
472        self.get().hash(state);
473    }
474}
475
476impl fmt::Display for RawJson {
477    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478        f.write_str(self.get())
479    }
480}
481
482impl From<Box<RawValue>> for RawJson {
483    fn from(value: Box<RawValue>) -> Self {
484        Self(value)
485    }
486}
487
488#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
489#[serde(transparent)]
490pub struct ToolCallId(String);
491
492impl ToolCallId {
493    pub fn new(id: impl Into<String>) -> Self {
494        Self(id.into())
495    }
496
497    pub fn as_str(&self) -> &str {
498        &self.0
499    }
500}
501
502impl fmt::Display for ToolCallId {
503    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504        f.write_str(&self.0)
505    }
506}
507
508impl From<String> for ToolCallId {
509    fn from(value: String) -> Self {
510        Self(value)
511    }
512}
513
514impl From<&str> for ToolCallId {
515    fn from(value: &str) -> Self {
516        Self(value.to_string())
517    }
518}
519
520#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
521#[serde(transparent)]
522pub struct ToolName(String);
523
524impl ToolName {
525    pub fn new(name: impl Into<String>) -> Self {
526        Self(name.into())
527    }
528
529    pub fn as_str(&self) -> &str {
530        &self.0
531    }
532}
533
534impl fmt::Display for ToolName {
535    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
536        f.write_str(&self.0)
537    }
538}
539
540impl From<String> for ToolName {
541    fn from(value: String) -> Self {
542        Self(value)
543    }
544}
545
546impl From<&str> for ToolName {
547    fn from(value: &str) -> Self {
548        Self(value.to_string())
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    use crate::toolset::ToolInput;

    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct WeatherArgs {
        city: String,
    }

    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct WeatherResult {
        forecast: String,
    }

    impl ToolInput for WeatherArgs {
        type Output = WeatherResult;

        const NAME: &'static str = "weather";
        const DESCRIPTION: &'static str = "Get weather";
    }

    #[test]
    fn raw_json_rejects_invalid_json() {
        assert!(RawJson::parse("{").is_err());
        assert_eq!(
            RawJson::parse("{\"ok\":true}").unwrap().get(),
            "{\"ok\":true}"
        );
    }

    #[test]
    fn raw_json_round_trips_typed_values() {
        let args = WeatherArgs {
            city: "Tokyo".into(),
        };
        let raw = RawJson::from_serializable(&args).unwrap();

        assert_eq!(raw.get(), "{\"city\":\"Tokyo\"}");
        assert_eq!(raw.deserialize::<WeatherArgs>().unwrap(), args);
    }

    #[test]
    fn non_empty_rejects_empty_vectors() {
        assert!(NonEmpty::<String>::try_from_vec(vec![]).is_err());
    }

    #[test]
    fn non_empty_deserialization_rejects_empty_arrays() {
        assert!(serde_json::from_str::<NonEmpty<String>>("[]").is_err());

        let parsed: NonEmpty<String> = serde_json::from_str("[\"a\",\"b\"]").unwrap();
        assert_eq!(serde_json::to_string(&parsed).unwrap(), "[\"a\",\"b\"]");
        assert_eq!(parsed.into_vec(), vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn model_input_validation_rejects_empty_input() {
        assert_eq!(
            ModelInput::new().validate().unwrap_err(),
            ModelInputValidationError::Empty
        );
    }

    #[test]
    fn model_input_validation_accepts_distinct_tool_use_ids() {
        let input = ModelInput::new()
            .system("be brief")
            .user("weather?")
            .tool_use(ToolUse::new(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"sunny\"").unwrap(),
            ))
            .tool_use(ToolUse::new(
                "call-2",
                "weather",
                RawJson::parse("{\"city\":\"Osaka\"}").unwrap(),
                RawJson::parse("\"rainy\"").unwrap(),
            ));

        assert_eq!(input.validate(), Ok(()));
    }

    #[test]
    fn model_input_validation_rejects_duplicate_tool_use_ids() {
        let input = ModelInput::from_items(vec![
            ModelInputItem::text(InputMessageRole::User, "hello"),
            ModelInputItem::tool_use_parts(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"sunny\"").unwrap(),
            ),
            ModelInputItem::tool_use_parts(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"rainy\"").unwrap(),
            ),
        ]);

        assert_eq!(
            input.validate().unwrap_err(),
            ModelInputValidationError::DuplicateToolUseId {
                id: ToolCallId::from("call-1"),
            }
        );
    }

    #[test]
    fn assistant_turn_rejects_empty_items() {
        assert!(AssistantTurn::from_items(Vec::new()).is_err());
    }

    #[test]
    fn assistant_text_concatenates_only_text_items() {
        let turn = AssistantTurn::from_items(vec![
            AssistantTurnItem::Reasoning("thinking".into()),
            AssistantTurnItem::Text("Hello, ".into()),
            AssistantTurnItem::ToolCall {
                id: "call-1".into(),
                name: "weather".into(),
                arguments: RawJson::parse("{}").unwrap(),
            },
            AssistantTurnItem::Refusal("no".into()),
            AssistantTurnItem::Text("world".into()),
        ])
        .unwrap();

        assert_eq!(turn.assistant_text(), "Hello, world");
    }

    #[test]
    fn tool_input_serializes_result() {
        let tool_use = WeatherArgs::tool_use(
            ToolMetadata::new(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
            ),
            WeatherResult {
                forecast: "sunny".into(),
            },
        )
        .unwrap();

        assert_eq!(tool_use.result.get(), "{\"forecast\":\"sunny\"}");
    }
}