1use std::{borrow::Borrow, fmt};
2
3use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned};
4use serde_json::value::RawValue;
5use thiserror::Error;
6
7use crate::transcript::CommittedTurn;
8
9#[derive(Clone, Debug, Default)]
10pub struct ModelInput {
11 items: Vec<ModelInputItem>,
12}
13
14impl ModelInput {
15 pub fn new() -> Self {
16 Self { items: Vec::new() }
17 }
18
19 pub fn from_items(items: Vec<ModelInputItem>) -> Self {
20 Self { items }
21 }
22
23 pub fn items(&self) -> &[ModelInputItem] {
24 &self.items
25 }
26
27 pub fn into_items(self) -> Vec<ModelInputItem> {
28 self.items
29 }
30
31 pub fn push(&mut self, item: ModelInputItem) {
32 self.items.push(item);
33 }
34
35 pub fn system(mut self, text: impl Into<String>) -> Self {
36 self.push(ModelInputItem::text(InputMessageRole::System, text));
37 self
38 }
39
40 pub fn developer(mut self, text: impl Into<String>) -> Self {
41 self.push(ModelInputItem::text(InputMessageRole::Developer, text));
42 self
43 }
44
45 pub fn user(mut self, text: impl Into<String>) -> Self {
46 self.push(ModelInputItem::text(InputMessageRole::User, text));
47 self
48 }
49
50 pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
51 self.push(ModelInputItem::assistant_text(text));
52 self
53 }
54
55 pub fn assistant_reasoning(mut self, text: impl Into<String>) -> Self {
56 self.push(ModelInputItem::assistant_reasoning(text));
57 self
58 }
59
60 pub fn assistant_refusal(mut self, text: impl Into<String>) -> Self {
61 self.push(ModelInputItem::assistant_refusal(text));
62 self
63 }
64
65 pub fn tool_use(mut self, tool_use: ToolUse) -> Self {
66 self.push(ModelInputItem::tool_use(tool_use));
67 self
68 }
69
70 pub fn validate(&self) -> Result<(), ModelInputValidationError> {
71 if self.items.is_empty() {
72 return Err(ModelInputValidationError::Empty);
73 }
74
75 let mut tool_uses = std::collections::BTreeSet::new();
76 for item in &self.items {
77 if let ModelInputItem::ToolUse(tool_use) = item
78 && !tool_uses.insert(tool_use.id.clone())
79 {
80 return Err(ModelInputValidationError::DuplicateToolUseId {
81 id: tool_use.id.clone(),
82 });
83 }
84 }
85
86 Ok(())
87 }
88}
89
90impl From<Vec<ModelInputItem>> for ModelInput {
91 fn from(items: Vec<ModelInputItem>) -> Self {
92 Self::from_items(items)
93 }
94}
95
96#[derive(Clone, Debug)]
97pub enum ModelInputItem {
98 Message {
99 role: InputMessageRole,
100 content: NonEmpty<MessageContent>,
101 },
102 Assistant(AssistantInputItem),
103 ToolUse(ToolUse),
104 Turn(CommittedTurn),
105}
106
107impl ModelInputItem {
108 pub fn message(role: InputMessageRole, content: NonEmpty<MessageContent>) -> Self {
109 Self::Message { role, content }
110 }
111
112 pub fn text(role: InputMessageRole, text: impl Into<String>) -> Self {
113 Self::Message {
114 role,
115 content: NonEmpty::one(MessageContent::Text(text.into())),
116 }
117 }
118
119 pub fn assistant(item: AssistantInputItem) -> Self {
120 Self::Assistant(item)
121 }
122
123 pub fn assistant_text(text: impl Into<String>) -> Self {
124 Self::Assistant(AssistantInputItem::Text(text.into()))
125 }
126
127 pub fn assistant_reasoning(text: impl Into<String>) -> Self {
128 Self::Assistant(AssistantInputItem::Reasoning(text.into()))
129 }
130
131 pub fn assistant_refusal(text: impl Into<String>) -> Self {
132 Self::Assistant(AssistantInputItem::Refusal(text.into()))
133 }
134
135 pub fn tool_use(tool_use: ToolUse) -> Self {
136 Self::ToolUse(tool_use)
137 }
138
139 pub fn turn(committed_turn: CommittedTurn) -> Self {
140 Self::Turn(committed_turn)
141 }
142
143 pub fn tool_use_parts(
144 id: impl Into<ToolCallId>,
145 name: impl Into<ToolName>,
146 arguments: RawJson,
147 result: RawJson,
148 ) -> Self {
149 Self::ToolUse(ToolUse::new(id, name, arguments, result))
150 }
151}
152
153#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
154pub enum InputMessageRole {
155 System,
156 Developer,
157 User,
158}
159
160#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
161pub enum MessageContent {
162 Text(String),
163}
164
165#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
166pub enum AssistantInputItem {
171 Text(String),
172 Reasoning(String),
173 Refusal(String),
174}
175
176#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
177pub struct ToolUse {
178 pub id: ToolCallId,
179 pub name: ToolName,
180 pub arguments: RawJson,
181 pub result: RawJson,
182}
183
184impl ToolUse {
185 pub fn new(
186 id: impl Into<ToolCallId>,
187 name: impl Into<ToolName>,
188 arguments: RawJson,
189 result: RawJson,
190 ) -> Self {
191 Self {
192 id: id.into(),
193 name: name.into(),
194 arguments,
195 result,
196 }
197 }
198}
199
200#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
201pub struct ToolMetadata {
202 pub id: ToolCallId,
203 pub name: ToolName,
204 pub arguments: RawJson,
205}
206
207impl ToolMetadata {
208 pub fn new(id: impl Into<ToolCallId>, name: impl Into<ToolName>, arguments: RawJson) -> Self {
209 Self {
210 id: id.into(),
211 name: name.into(),
212 arguments,
213 }
214 }
215
216 pub fn into_tool_use(self, result: RawJson) -> ToolUse {
217 ToolUse::new(self.id, self.name, self.arguments, result)
218 }
219}
220
221#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
222pub struct AssistantTurn {
227 items: NonEmpty<AssistantTurnItem>,
228}
229
230impl AssistantTurn {
231 pub fn new(items: NonEmpty<AssistantTurnItem>) -> Self {
232 Self { items }
233 }
234
235 pub fn from_items(items: Vec<AssistantTurnItem>) -> Result<Self, EmptyNonEmptyError> {
236 Ok(Self::new(NonEmpty::try_from_vec(items)?))
237 }
238
239 pub fn items(&self) -> &[AssistantTurnItem] {
240 self.items.as_slice()
241 }
242
243 pub fn items_non_empty(&self) -> &NonEmpty<AssistantTurnItem> {
244 &self.items
245 }
246
247 pub fn into_items(self) -> NonEmpty<AssistantTurnItem> {
248 self.items
249 }
250
251 pub fn text(text: impl Into<String>) -> Self {
252 Self::new(NonEmpty::one(AssistantTurnItem::Text(text.into())))
253 }
254
255 pub fn reasoning(text: impl Into<String>) -> Self {
256 Self::new(NonEmpty::one(AssistantTurnItem::Reasoning(text.into())))
257 }
258
259 pub fn refusal(text: impl Into<String>) -> Self {
260 Self::new(NonEmpty::one(AssistantTurnItem::Refusal(text.into())))
261 }
262
263 pub fn tool_call(
264 id: impl Into<ToolCallId>,
265 name: impl Into<ToolName>,
266 arguments: RawJson,
267 ) -> Self {
268 Self::new(NonEmpty::one(AssistantTurnItem::ToolCall {
269 id: id.into(),
270 name: name.into(),
271 arguments,
272 }))
273 }
274
275 pub fn assistant_text(&self) -> String {
276 let mut text = String::new();
277 for item in self.items() {
278 if let AssistantTurnItem::Text(delta) = item {
279 text.push_str(delta);
280 }
281 }
282 text
283 }
284}
285
286#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
287pub enum AssistantTurnItem {
292 Text(String),
293 Reasoning(String),
294 Refusal(String),
295 ToolCall {
296 id: ToolCallId,
297 name: ToolName,
298 arguments: RawJson,
299 },
300}
301
302#[derive(Clone, Debug, Eq, PartialEq)]
303pub struct NonEmpty<T>(Vec<T>);
304
305impl<T> NonEmpty<T> {
306 pub fn one(item: T) -> Self {
307 Self(vec![item])
308 }
309
310 pub fn try_from_vec(items: Vec<T>) -> Result<Self, EmptyNonEmptyError> {
311 if items.is_empty() {
312 Err(EmptyNonEmptyError)
313 } else {
314 Ok(Self(items))
315 }
316 }
317
318 pub fn as_slice(&self) -> &[T] {
319 &self.0
320 }
321
322 pub fn iter(&self) -> std::slice::Iter<'_, T> {
323 self.0.iter()
324 }
325
326 pub fn into_vec(self) -> Vec<T> {
327 self.0
328 }
329}
330
331impl<T> TryFrom<Vec<T>> for NonEmpty<T> {
332 type Error = EmptyNonEmptyError;
333
334 fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
335 Self::try_from_vec(value)
336 }
337}
338
339impl<T> From<NonEmpty<T>> for Vec<T> {
340 fn from(value: NonEmpty<T>) -> Self {
341 value.0
342 }
343}
344
345impl<T> Serialize for NonEmpty<T>
346where
347 T: Serialize,
348{
349 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
350 where
351 S: Serializer,
352 {
353 self.0.serialize(serializer)
354 }
355}
356
357impl<'de, T> Deserialize<'de> for NonEmpty<T>
358where
359 T: Deserialize<'de>,
360{
361 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
362 where
363 D: Deserializer<'de>,
364 {
365 let values = Vec::<T>::deserialize(deserializer)?;
366 Self::try_from_vec(values).map_err(serde::de::Error::custom)
367 }
368}
369
370impl<T> Borrow<[T]> for NonEmpty<T> {
371 fn borrow(&self) -> &[T] {
372 self.as_slice()
373 }
374}
375
376#[derive(Debug, Error, Clone, Copy, Eq, PartialEq)]
377#[error("non-empty collection must contain at least one element")]
378pub struct EmptyNonEmptyError;
379
380#[derive(Debug, Error, Clone, Eq, PartialEq)]
381pub enum ModelInputValidationError {
382 #[error("model input must contain at least one item")]
383 Empty,
384 #[error("duplicate tool use id `{id}` in model input")]
385 DuplicateToolUseId { id: ToolCallId },
386}
387
388#[derive(Debug, Error, Clone, Eq, PartialEq)]
389pub enum AssistantTurnInputError {
390 #[error("assistant turn references missing tool use `{id}`")]
391 MissingToolUse { id: ToolCallId },
392 #[error("assistant turn received duplicate tool use `{id}`")]
393 DuplicateToolUse { id: ToolCallId },
394 #[error("assistant turn received extra tool use `{id}`")]
395 ExtraToolUse { id: ToolCallId },
396 #[error("assistant turn tool call `{id}` expected tool name `{expected}`, got `{actual}`")]
397 MismatchedToolName {
398 id: ToolCallId,
399 expected: ToolName,
400 actual: ToolName,
401 },
402 #[error("assistant turn tool call `{id}` received mismatched arguments")]
403 MismatchedToolArguments {
404 id: ToolCallId,
405 expected: RawJson,
406 actual: RawJson,
407 },
408}
409
410#[derive(Serialize, Deserialize)]
411#[serde(transparent)]
412pub struct RawJson(Box<RawValue>);
413
414impl RawJson {
415 pub fn parse(json: impl Into<String>) -> Result<Self, serde_json::Error> {
416 RawValue::from_string(json.into()).map(Self)
417 }
418
419 pub fn from_serializable<T>(value: &T) -> Result<Self, serde_json::Error>
420 where
421 T: Serialize,
422 {
423 RawValue::from_string(serde_json::to_string(value)?).map(Self)
424 }
425
426 pub fn get(&self) -> &str {
427 self.0.get()
428 }
429
430 pub fn deserialize<T>(&self) -> Result<T, serde_json::Error>
431 where
432 T: DeserializeOwned,
433 {
434 serde_json::from_str(self.get())
435 }
436}
437
438impl Clone for RawJson {
439 fn clone(&self) -> Self {
440 Self(self.0.clone())
441 }
442}
443
444impl fmt::Debug for RawJson {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 f.debug_tuple("RawJson").field(&self.get()).finish()
447 }
448}
449
450impl PartialEq for RawJson {
451 fn eq(&self, other: &Self) -> bool {
452 self.get() == other.get()
453 }
454}
455
456impl Eq for RawJson {}
457
458impl PartialOrd for RawJson {
459 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
460 Some(self.cmp(other))
461 }
462}
463
464impl Ord for RawJson {
465 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
466 self.get().cmp(other.get())
467 }
468}
469
470impl std::hash::Hash for RawJson {
471 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
472 self.get().hash(state);
473 }
474}
475
476impl fmt::Display for RawJson {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 f.write_str(self.get())
479 }
480}
481
482impl From<Box<RawValue>> for RawJson {
483 fn from(value: Box<RawValue>) -> Self {
484 Self(value)
485 }
486}
487
488#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
489#[serde(transparent)]
490pub struct ToolCallId(String);
491
492impl ToolCallId {
493 pub fn new(id: impl Into<String>) -> Self {
494 Self(id.into())
495 }
496
497 pub fn as_str(&self) -> &str {
498 &self.0
499 }
500}
501
502impl fmt::Display for ToolCallId {
503 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
504 f.write_str(&self.0)
505 }
506}
507
508impl From<String> for ToolCallId {
509 fn from(value: String) -> Self {
510 Self(value)
511 }
512}
513
514impl From<&str> for ToolCallId {
515 fn from(value: &str) -> Self {
516 Self(value.to_string())
517 }
518}
519
520#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
521#[serde(transparent)]
522pub struct ToolName(String);
523
524impl ToolName {
525 pub fn new(name: impl Into<String>) -> Self {
526 Self(name.into())
527 }
528
529 pub fn as_str(&self) -> &str {
530 &self.0
531 }
532}
533
534impl fmt::Display for ToolName {
535 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
536 f.write_str(&self.0)
537 }
538}
539
540impl From<String> for ToolName {
541 fn from(value: String) -> Self {
542 Self(value)
543 }
544}
545
546impl From<&str> for ToolName {
547 fn from(value: &str) -> Self {
548 Self(value.to_string())
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    use crate::toolset::ToolInput;

    /// Arguments of the `weather` test tool.
    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct WeatherArgs {
        city: String,
    }

    /// Output of the `weather` test tool.
    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct WeatherResult {
        forecast: String,
    }

    impl ToolInput for WeatherArgs {
        type Output = WeatherResult;

        const NAME: &'static str = "weather";
        const DESCRIPTION: &'static str = "Get weather";
    }

    #[test]
    fn raw_json_rejects_invalid_json() {
        assert!(RawJson::parse("{").is_err());
        assert_eq!(
            RawJson::parse("{\"ok\":true}").unwrap().get(),
            "{\"ok\":true}"
        );
    }

    #[test]
    fn raw_json_from_serializable_round_trips() {
        let args = WeatherArgs {
            city: "Tokyo".into(),
        };
        let raw = RawJson::from_serializable(&args).unwrap();
        assert_eq!(raw.get(), "{\"city\":\"Tokyo\"}");
        assert_eq!(raw.deserialize::<WeatherArgs>().unwrap(), args);
    }

    #[test]
    fn non_empty_rejects_empty_vectors() {
        assert!(NonEmpty::<String>::try_from_vec(vec![]).is_err());
    }

    #[test]
    fn non_empty_deserialization_rejects_empty_sequences() {
        assert!(serde_json::from_str::<NonEmpty<String>>("[]").is_err());
        let parsed: NonEmpty<String> = serde_json::from_str("[\"a\",\"b\"]").unwrap();
        assert_eq!(
            parsed.into_vec(),
            vec!["a".to_string(), "b".to_string()]
        );
    }

    #[test]
    fn model_input_validation_rejects_empty_input() {
        assert_eq!(
            ModelInput::new().validate(),
            Err(ModelInputValidationError::Empty)
        );
    }

    #[test]
    fn model_input_validation_accepts_distinct_tool_use_ids() {
        let input = ModelInput::new()
            .user("hello")
            .tool_use(ToolUse::new(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"sunny\"").unwrap(),
            ))
            .tool_use(ToolUse::new(
                "call-2",
                "weather",
                RawJson::parse("{\"city\":\"Osaka\"}").unwrap(),
                RawJson::parse("\"rainy\"").unwrap(),
            ));

        assert_eq!(input.validate(), Ok(()));
    }

    #[test]
    fn model_input_validation_rejects_duplicate_tool_use_ids() {
        let input = ModelInput::from_items(vec![
            ModelInputItem::text(InputMessageRole::User, "hello"),
            ModelInputItem::tool_use_parts(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"sunny\"").unwrap(),
            ),
            ModelInputItem::tool_use_parts(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
                RawJson::parse("\"rainy\"").unwrap(),
            ),
        ]);

        assert_eq!(
            input.validate().unwrap_err(),
            ModelInputValidationError::DuplicateToolUseId {
                id: ToolCallId::from("call-1"),
            }
        );
    }

    #[test]
    fn assistant_turn_requires_items_and_concatenates_text() {
        assert_eq!(
            AssistantTurn::from_items(vec![]).unwrap_err(),
            EmptyNonEmptyError
        );

        let turn = AssistantTurn::from_items(vec![
            AssistantTurnItem::Text("Hello, ".into()),
            AssistantTurnItem::Reasoning("thinking".into()),
            AssistantTurnItem::Text("world".into()),
        ])
        .unwrap();
        assert_eq!(turn.assistant_text(), "Hello, world");
    }

    #[test]
    fn tool_input_serializes_result() {
        let tool_use = WeatherArgs::tool_use(
            ToolMetadata::new(
                "call-1",
                "weather",
                RawJson::parse("{\"city\":\"Tokyo\"}").unwrap(),
            ),
            WeatherResult {
                forecast: "sunny".into(),
            },
        )
        .unwrap();

        assert_eq!(tool_use.result.get(), "{\"forecast\":\"sunny\"}");
    }
}