Skip to main content

lutum_protocol/
llm.rs

1use std::{borrow::Cow, fmt, marker::PhantomData, pin::Pin, sync::Arc};
2
3use bon::Builder;
4use futures::Stream;
5
6use crate::{
7    AgentError,
8    budget::{RequestBudget, Usage},
9    conversation::{ModelInput, RawJson, ToolMetadata},
10    structured::StructuredOutput,
11    toolset::{NoTools, ToolPolicy, Toolset},
12    transcript::CommittedTurn,
13};
14
15pub type TextTurnEventStream<T, E = AgentError> =
16    Pin<Box<dyn Stream<Item = Result<TextTurnEvent<T>, E>> + Send + Sync + 'static>>;
17pub type StructuredTurnEventStream<T, O, E = AgentError> =
18    Pin<Box<dyn Stream<Item = Result<StructuredTurnEvent<T, O>, E>> + Send + Sync + 'static>>;
19pub type StructuredCompletionEventStream<O, E = AgentError> =
20    Pin<Box<dyn Stream<Item = Result<StructuredCompletionEvent<O>, E>> + Send + Sync + 'static>>;
21pub type CompletionEventStream<E = AgentError> =
22    Pin<Box<dyn Stream<Item = Result<CompletionEvent, E>> + Send + Sync + 'static>>;
23pub type ErasedTextTurnEventStream<E = AgentError> =
24    Pin<Box<dyn Stream<Item = Result<ErasedTextTurnEvent, E>> + Send + Sync + 'static>>;
25pub type ErasedStructuredTurnEventStream<E = AgentError> =
26    Pin<Box<dyn Stream<Item = Result<ErasedStructuredTurnEvent, E>> + Send + Sync + 'static>>;
27pub type ErasedStructuredCompletionEventStream<E = AgentError> =
28    Pin<Box<dyn Stream<Item = Result<ErasedStructuredCompletionEvent, E>> + Send + Sync + 'static>>;
29
30#[derive(
31    Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, serde::Serialize, serde::Deserialize,
32)]
33#[serde(try_from = "String", into = "String")]
34pub struct ModelName(String);
35
36impl ModelName {
37    pub fn new(model: impl Into<String>) -> Result<Self, ModelNameError> {
38        let model = model.into();
39        if model.trim().is_empty() {
40            return Err(ModelNameError::Empty);
41        }
42        Ok(Self(model))
43    }
44
45    pub fn as_str(&self) -> &str {
46        &self.0
47    }
48
49    pub fn into_string(self) -> String {
50        self.0
51    }
52}
53
54impl AsRef<str> for ModelName {
55    fn as_ref(&self) -> &str {
56        self.as_str()
57    }
58}
59
60impl fmt::Display for ModelName {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.write_str(self.as_str())
63    }
64}
65
66impl From<ModelName> for String {
67    fn from(value: ModelName) -> Self {
68        value.into_string()
69    }
70}
71
72impl TryFrom<String> for ModelName {
73    type Error = ModelNameError;
74
75    fn try_from(value: String) -> Result<Self, Self::Error> {
76        Self::new(value)
77    }
78}
79
80impl TryFrom<&str> for ModelName {
81    type Error = ModelNameError;
82
83    fn try_from(value: &str) -> Result<Self, Self::Error> {
84        Self::new(value)
85    }
86}
87
88#[derive(Clone, Debug, thiserror::Error, Eq, PartialEq)]
89pub enum ModelNameError {
90    #[error("model must not be empty")]
91    Empty,
92}
93
94#[derive(Builder, Clone, Debug, Default, PartialEq)]
95#[builder(builder_type(name = GenerationParamsBuilder))]
96pub struct GenerationParams {
97    pub temperature: Option<Temperature>,
98    pub max_output_tokens: Option<u32>,
99    pub seed: Option<u64>,
100}
101
102#[derive(Builder, Clone, Debug, PartialEq)]
103#[builder(builder_type(name = TurnConfigBuilder))]
104pub struct TurnConfig<T: Toolset = NoTools> {
105    pub model: ModelName,
106    #[builder(default)]
107    pub generation: GenerationParams,
108    #[builder(default)]
109    pub tools: ToolPolicy<T>,
110    #[builder(default = RequestBudget::unlimited())]
111    pub budget: RequestBudget,
112}
113
114impl<T> TurnConfig<T>
115where
116    T: Toolset,
117{
118    pub fn new(model: ModelName) -> Self {
119        Self {
120            model,
121            generation: GenerationParams::default(),
122            tools: ToolPolicy::Disabled,
123            budget: RequestBudget::unlimited(),
124        }
125    }
126}
127
128#[derive(Builder, Clone, Debug, PartialEq)]
129#[builder(builder_type(name = TextTurnBuilder))]
130pub struct TextTurn<T: Toolset = NoTools> {
131    pub config: TurnConfig<T>,
132}
133
134impl<T> TextTurn<T>
135where
136    T: Toolset,
137{
138    pub fn new(model: ModelName) -> Self {
139        Self {
140            config: TurnConfig::new(model),
141        }
142    }
143}
144
145#[derive(Builder, Clone, Debug, PartialEq)]
146#[builder(builder_type(name = StructuredOutputSpecBuilder))]
147pub struct StructuredOutputSpec<O: StructuredOutput> {
148    #[builder(skip = PhantomData)]
149    _marker: PhantomData<fn() -> O>,
150}
151
152impl<O> Default for StructuredOutputSpec<O>
153where
154    O: StructuredOutput,
155{
156    fn default() -> Self {
157        Self {
158            _marker: PhantomData,
159        }
160    }
161}
162
163#[derive(Builder, Clone, Debug, PartialEq)]
164#[builder(builder_type(name = StructuredTurnBuilder))]
165pub struct StructuredTurn<T: Toolset, O: StructuredOutput> {
166    pub config: TurnConfig<T>,
167    #[builder(default)]
168    pub output: StructuredOutputSpec<O>,
169}
170
171impl<T, O> StructuredTurn<T, O>
172where
173    T: Toolset,
174    O: StructuredOutput,
175{
176    pub fn new(model: ModelName) -> Self {
177        Self {
178            config: TurnConfig::new(model),
179            output: StructuredOutputSpec::default(),
180        }
181    }
182}
183
184#[derive(Builder, Clone, Debug, PartialEq)]
185#[builder(builder_type(name = StructuredCompletionRequestBuilder))]
186pub struct StructuredCompletionRequest<O: StructuredOutput> {
187    pub model: ModelName,
188    pub system: Option<String>,
189    #[builder(into)]
190    pub prompt: String,
191    #[builder(default)]
192    pub generation: GenerationParams,
193    #[builder(default = RequestBudget::unlimited())]
194    pub budget: RequestBudget,
195    #[builder(default)]
196    pub output: StructuredOutputSpec<O>,
197}
198
199impl<O> StructuredCompletionRequest<O>
200where
201    O: StructuredOutput,
202{
203    pub fn new(model: ModelName, prompt: impl Into<String>) -> Self {
204        Self {
205            model,
206            system: None,
207            prompt: prompt.into(),
208            generation: GenerationParams::default(),
209            budget: RequestBudget::unlimited(),
210            output: StructuredOutputSpec::default(),
211        }
212    }
213
214    pub fn with_system(mut self, system: impl Into<String>) -> Self {
215        self.system = Some(system.into());
216        self
217    }
218}
219
220#[derive(Clone, Debug, PartialEq, Eq)]
221pub struct AdapterToolDefinition {
222    pub name: String,
223    pub description: String,
224    pub input_schema: serde_json::Value,
225}
226
227#[derive(Clone, Debug, PartialEq, Eq)]
228pub enum AdapterToolChoice {
229    None,
230    Auto,
231    Required,
232    Specific(String),
233}
234
235#[derive(Clone, Debug, PartialEq)]
236pub struct AdapterTurnConfig {
237    pub model: ModelName,
238    pub generation: GenerationParams,
239    pub tools: Vec<AdapterToolDefinition>,
240    pub tool_choice: AdapterToolChoice,
241}
242
243#[derive(Clone, Debug, Default, Eq, PartialEq)]
244pub struct ModelSelection {
245    /// Borrowed model names must be truly `'static`; runtime-derived values
246    /// should use `Cow::Owned`.
247    pub primary: Option<Cow<'static, str>>,
248    /// Borrowed model names must be truly `'static`; runtime-derived values
249    /// should use `Cow::Owned`.
250    pub fallbacks: Option<Vec<Cow<'static, str>>>,
251}
252
253pub trait ModelSelector: Send + Sync {
254    fn select_model(&self, extensions: &crate::extensions::RequestExtensions) -> ModelSelection;
255}
256
257#[derive(Clone)]
258pub struct AdapterTextTurn {
259    pub config: AdapterTurnConfig,
260    pub extensions: Arc<crate::extensions::RequestExtensions>,
261}
262
263#[derive(Clone, Debug, PartialEq)]
264pub struct AdapterStructuredOutputSpec {
265    pub schema_name: String,
266    pub schema: serde_json::Value,
267}
268
269#[derive(Clone)]
270pub struct AdapterStructuredTurn {
271    pub config: AdapterTurnConfig,
272    pub extensions: Arc<crate::extensions::RequestExtensions>,
273    pub output: AdapterStructuredOutputSpec,
274}
275
276#[derive(Clone, Debug, PartialEq)]
277pub struct AdapterStructuredCompletionRequest {
278    pub model: ModelName,
279    pub system: Option<String>,
280    pub prompt: String,
281    pub generation: GenerationParams,
282    pub output: AdapterStructuredOutputSpec,
283}
284
285impl fmt::Debug for AdapterTextTurn {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        f.debug_struct("AdapterTextTurn")
288            .field("config", &self.config)
289            .field("extensions", &"<opaque>")
290            .finish()
291    }
292}
293
294impl PartialEq for AdapterTextTurn {
295    fn eq(&self, other: &Self) -> bool {
296        self.config == other.config && Arc::ptr_eq(&self.extensions, &other.extensions)
297    }
298}
299
300impl fmt::Debug for AdapterStructuredTurn {
301    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302        f.debug_struct("AdapterStructuredTurn")
303            .field("config", &self.config)
304            .field("extensions", &"<opaque>")
305            .field("output", &self.output)
306            .finish()
307    }
308}
309
310impl PartialEq for AdapterStructuredTurn {
311    fn eq(&self, other: &Self) -> bool {
312        self.config == other.config
313            && Arc::ptr_eq(&self.extensions, &other.extensions)
314            && self.output == other.output
315    }
316}
317
318#[derive(Clone, Copy, Debug, PartialEq)]
319pub struct Temperature(f32);
320
321impl Temperature {
322    pub const MIN: f32 = 0.0;
323    pub const MAX: f32 = 2.0;
324
325    pub fn new(value: f32) -> Result<Self, TemperatureError> {
326        if !value.is_finite() {
327            return Err(TemperatureError::NonFinite);
328        }
329        if !(Self::MIN..=Self::MAX).contains(&value) {
330            return Err(TemperatureError::OutOfRange { value });
331        }
332        Ok(Self(value))
333    }
334
335    pub fn get(self) -> f32 {
336        self.0
337    }
338}
339
340impl TryFrom<f32> for Temperature {
341    type Error = TemperatureError;
342
343    fn try_from(value: f32) -> Result<Self, Self::Error> {
344        Self::new(value)
345    }
346}
347
348#[derive(Clone, Debug, thiserror::Error, PartialEq)]
349pub enum TemperatureError {
350    #[error("temperature must be finite")]
351    NonFinite,
352    #[error("temperature {value} must be in the range [0.0, 2.0]")]
353    OutOfRange { value: f32 },
354}
355
356#[derive(Builder, Clone, Debug, PartialEq)]
357#[builder(builder_type(name = CompletionRequestBuilder))]
358pub struct CompletionRequest {
359    pub model: ModelName,
360    #[builder(into)]
361    pub prompt: String,
362    #[builder(default)]
363    pub options: CompletionOptions,
364    #[builder(default = RequestBudget::unlimited())]
365    pub budget: RequestBudget,
366}
367
368impl CompletionRequest {
369    pub fn new(model: ModelName, prompt: impl Into<String>) -> Self {
370        Self {
371            model,
372            prompt: prompt.into(),
373            options: CompletionOptions::default(),
374            budget: RequestBudget::unlimited(),
375        }
376    }
377
378    pub fn with_options(mut self, options: CompletionOptions) -> Self {
379        self.options = options;
380        self
381    }
382
383    pub fn with_budget(mut self, budget: RequestBudget) -> Self {
384        self.budget = budget;
385        self
386    }
387}
388
389#[derive(Builder, Clone, Debug, Default, PartialEq)]
390#[builder(builder_type(name = CompletionOptionsBuilder))]
391pub struct CompletionOptions {
392    pub temperature: Option<Temperature>,
393    pub max_output_tokens: Option<u32>,
394    #[builder(default)]
395    pub stop: Vec<String>,
396}
397
398#[derive(Clone, Debug)]
399pub enum TextTurnEvent<T: Toolset> {
400    Started {
401        request_id: Option<String>,
402        model: String,
403    },
404    TextDelta {
405        delta: String,
406    },
407    ReasoningDelta {
408        delta: String,
409    },
410    RefusalDelta {
411        delta: String,
412    },
413    ToolCallChunk {
414        id: crate::conversation::ToolCallId,
415        name: crate::conversation::ToolName,
416        arguments_json_delta: String,
417    },
418    ToolCallReady(T::ToolCall),
419    Completed {
420        request_id: Option<String>,
421        finish_reason: FinishReason,
422        usage: Usage,
423        committed_turn: CommittedTurn,
424    },
425}
426
427#[derive(Clone, Debug)]
428pub enum StructuredTurnEvent<T: Toolset, O: StructuredOutput> {
429    Started {
430        request_id: Option<String>,
431        model: String,
432    },
433    StructuredOutputChunk {
434        json_delta: String,
435    },
436    StructuredOutputReady(O),
437    ReasoningDelta {
438        delta: String,
439    },
440    RefusalDelta {
441        delta: String,
442    },
443    ToolCallChunk {
444        id: crate::conversation::ToolCallId,
445        name: crate::conversation::ToolName,
446        arguments_json_delta: String,
447    },
448    ToolCallReady(T::ToolCall),
449    Completed {
450        request_id: Option<String>,
451        finish_reason: FinishReason,
452        usage: Usage,
453        committed_turn: CommittedTurn,
454    },
455}
456
457#[derive(Clone, Debug)]
458pub enum ErasedTextTurnEvent {
459    Started {
460        request_id: Option<String>,
461        model: String,
462    },
463    TextDelta {
464        delta: String,
465    },
466    ReasoningDelta {
467        delta: String,
468    },
469    RefusalDelta {
470        delta: String,
471    },
472    ToolCallChunk {
473        id: crate::conversation::ToolCallId,
474        name: crate::conversation::ToolName,
475        arguments_json_delta: String,
476    },
477    ToolCallReady(ToolMetadata),
478    Completed {
479        request_id: Option<String>,
480        finish_reason: FinishReason,
481        usage: Usage,
482        committed_turn: CommittedTurn,
483    },
484}
485
486#[derive(Clone, Debug)]
487pub enum ErasedStructuredTurnEvent {
488    Started {
489        request_id: Option<String>,
490        model: String,
491    },
492    StructuredOutputChunk {
493        json_delta: String,
494    },
495    StructuredOutputReady(RawJson),
496    ReasoningDelta {
497        delta: String,
498    },
499    RefusalDelta {
500        delta: String,
501    },
502    ToolCallChunk {
503        id: crate::conversation::ToolCallId,
504        name: crate::conversation::ToolName,
505        arguments_json_delta: String,
506    },
507    ToolCallReady(ToolMetadata),
508    Completed {
509        request_id: Option<String>,
510        finish_reason: FinishReason,
511        usage: Usage,
512        committed_turn: CommittedTurn,
513    },
514}
515
516impl<T> PartialEq for TextTurnEvent<T>
517where
518    T: Toolset,
519    T::ToolCall: PartialEq,
520{
521    fn eq(&self, other: &Self) -> bool {
522        match (self, other) {
523            (
524                Self::Started {
525                    request_id: lhs_request_id,
526                    model: lhs_model,
527                },
528                Self::Started {
529                    request_id: rhs_request_id,
530                    model: rhs_model,
531                },
532            ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
533            (Self::TextDelta { delta: lhs }, Self::TextDelta { delta: rhs }) => lhs == rhs,
534            (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
535                lhs == rhs
536            }
537            (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
538            (
539                Self::ToolCallChunk {
540                    id: lhs_id,
541                    name: lhs_name,
542                    arguments_json_delta: lhs_delta,
543                },
544                Self::ToolCallChunk {
545                    id: rhs_id,
546                    name: rhs_name,
547                    arguments_json_delta: rhs_delta,
548                },
549            ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
550            (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
551            (
552                Self::Completed {
553                    request_id: lhs_request_id,
554                    finish_reason: lhs_finish_reason,
555                    usage: lhs_usage,
556                    committed_turn: lhs_committed_turn,
557                },
558                Self::Completed {
559                    request_id: rhs_request_id,
560                    finish_reason: rhs_finish_reason,
561                    usage: rhs_usage,
562                    committed_turn: rhs_committed_turn,
563                },
564            ) => {
565                lhs_request_id == rhs_request_id
566                    && lhs_finish_reason == rhs_finish_reason
567                    && lhs_usage == rhs_usage
568                    && Arc::ptr_eq(lhs_committed_turn, rhs_committed_turn)
569            }
570            _ => false,
571        }
572    }
573}
574
575impl<T, O> PartialEq for StructuredTurnEvent<T, O>
576where
577    T: Toolset,
578    T::ToolCall: PartialEq,
579    O: StructuredOutput + PartialEq,
580{
581    fn eq(&self, other: &Self) -> bool {
582        match (self, other) {
583            (
584                Self::Started {
585                    request_id: lhs_request_id,
586                    model: lhs_model,
587                },
588                Self::Started {
589                    request_id: rhs_request_id,
590                    model: rhs_model,
591                },
592            ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
593            (
594                Self::StructuredOutputChunk { json_delta: lhs },
595                Self::StructuredOutputChunk { json_delta: rhs },
596            ) => lhs == rhs,
597            (Self::StructuredOutputReady(lhs), Self::StructuredOutputReady(rhs)) => lhs == rhs,
598            (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
599                lhs == rhs
600            }
601            (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
602            (
603                Self::ToolCallChunk {
604                    id: lhs_id,
605                    name: lhs_name,
606                    arguments_json_delta: lhs_delta,
607                },
608                Self::ToolCallChunk {
609                    id: rhs_id,
610                    name: rhs_name,
611                    arguments_json_delta: rhs_delta,
612                },
613            ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
614            (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
615            (
616                Self::Completed {
617                    request_id: lhs_request_id,
618                    finish_reason: lhs_finish_reason,
619                    usage: lhs_usage,
620                    committed_turn: lhs_committed_turn,
621                },
622                Self::Completed {
623                    request_id: rhs_request_id,
624                    finish_reason: rhs_finish_reason,
625                    usage: rhs_usage,
626                    committed_turn: rhs_committed_turn,
627                },
628            ) => {
629                lhs_request_id == rhs_request_id
630                    && lhs_finish_reason == rhs_finish_reason
631                    && lhs_usage == rhs_usage
632                    && Arc::ptr_eq(lhs_committed_turn, rhs_committed_turn)
633            }
634            _ => false,
635        }
636    }
637}
638
639impl PartialEq for ErasedTextTurnEvent {
640    fn eq(&self, other: &Self) -> bool {
641        match (self, other) {
642            (
643                Self::Started {
644                    request_id: lhs_request_id,
645                    model: lhs_model,
646                },
647                Self::Started {
648                    request_id: rhs_request_id,
649                    model: rhs_model,
650                },
651            ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
652            (Self::TextDelta { delta: lhs }, Self::TextDelta { delta: rhs }) => lhs == rhs,
653            (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
654                lhs == rhs
655            }
656            (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
657            (
658                Self::ToolCallChunk {
659                    id: lhs_id,
660                    name: lhs_name,
661                    arguments_json_delta: lhs_delta,
662                },
663                Self::ToolCallChunk {
664                    id: rhs_id,
665                    name: rhs_name,
666                    arguments_json_delta: rhs_delta,
667                },
668            ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
669            (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
670            (
671                Self::Completed {
672                    request_id: lhs_request_id,
673                    finish_reason: lhs_finish_reason,
674                    usage: lhs_usage,
675                    ..
676                },
677                Self::Completed {
678                    request_id: rhs_request_id,
679                    finish_reason: rhs_finish_reason,
680                    usage: rhs_usage,
681                    ..
682                },
683            ) => {
684                lhs_request_id == rhs_request_id
685                    && lhs_finish_reason == rhs_finish_reason
686                    && lhs_usage == rhs_usage
687            }
688            _ => false,
689        }
690    }
691}
692
693impl PartialEq for ErasedStructuredTurnEvent {
694    fn eq(&self, other: &Self) -> bool {
695        match (self, other) {
696            (
697                Self::Started {
698                    request_id: lhs_request_id,
699                    model: lhs_model,
700                },
701                Self::Started {
702                    request_id: rhs_request_id,
703                    model: rhs_model,
704                },
705            ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
706            (
707                Self::StructuredOutputChunk { json_delta: lhs },
708                Self::StructuredOutputChunk { json_delta: rhs },
709            ) => lhs == rhs,
710            (Self::StructuredOutputReady(lhs), Self::StructuredOutputReady(rhs)) => lhs == rhs,
711            (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
712                lhs == rhs
713            }
714            (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
715            (
716                Self::ToolCallChunk {
717                    id: lhs_id,
718                    name: lhs_name,
719                    arguments_json_delta: lhs_delta,
720                },
721                Self::ToolCallChunk {
722                    id: rhs_id,
723                    name: rhs_name,
724                    arguments_json_delta: rhs_delta,
725                },
726            ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
727            (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
728            (
729                Self::Completed {
730                    request_id: lhs_request_id,
731                    finish_reason: lhs_finish_reason,
732                    usage: lhs_usage,
733                    ..
734                },
735                Self::Completed {
736                    request_id: rhs_request_id,
737                    finish_reason: rhs_finish_reason,
738                    usage: rhs_usage,
739                    ..
740                },
741            ) => {
742                lhs_request_id == rhs_request_id
743                    && lhs_finish_reason == rhs_finish_reason
744                    && lhs_usage == rhs_usage
745            }
746            _ => false,
747        }
748    }
749}
750
751#[derive(Clone, Debug, Eq, PartialEq)]
752pub enum CompletionEvent {
753    Started {
754        request_id: Option<String>,
755        model: String,
756    },
757    TextDelta(String),
758    Completed {
759        request_id: Option<String>,
760        finish_reason: FinishReason,
761        usage: Usage,
762    },
763}
764
765#[derive(Clone, Debug, PartialEq)]
766pub enum StructuredCompletionEvent<O: StructuredOutput> {
767    Started {
768        request_id: Option<String>,
769        model: String,
770    },
771    StructuredOutputChunk {
772        json_delta: String,
773    },
774    StructuredOutputReady(O),
775    ReasoningDelta {
776        delta: String,
777    },
778    RefusalDelta {
779        delta: String,
780    },
781    Completed {
782        request_id: Option<String>,
783        finish_reason: FinishReason,
784        usage: Usage,
785    },
786}
787
788#[derive(Clone, Debug, Eq, PartialEq)]
789pub enum ErasedStructuredCompletionEvent {
790    Started {
791        request_id: Option<String>,
792        model: String,
793    },
794    StructuredOutputChunk {
795        json_delta: String,
796    },
797    StructuredOutputReady(RawJson),
798    ReasoningDelta {
799        delta: String,
800    },
801    RefusalDelta {
802        delta: String,
803    },
804    Completed {
805        request_id: Option<String>,
806        finish_reason: FinishReason,
807        usage: Usage,
808    },
809}
810
811#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
812pub enum FinishReason {
813    Stop,
814    Length,
815    ToolCall,
816    ContentFilter,
817    Unknown(String),
818}
819
/// Identifies which kind of operation a request belongs to; passed to
/// [`UsageRecoveryAdapter::recover_usage`].
///
/// The derived ordering follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OperationKind {
    TextTurn,
    StructuredTurn,
    StructuredCompletion,
    Completion,
}
827
828#[async_trait::async_trait]
829pub trait TurnAdapter: Send + Sync + 'static {
830    async fn text_turn(
831        &self,
832        input: ModelInput,
833        turn: AdapterTextTurn,
834    ) -> Result<ErasedTextTurnEventStream, AgentError>;
835
836    async fn structured_turn(
837        &self,
838        input: ModelInput,
839        turn: AdapterStructuredTurn,
840    ) -> Result<ErasedStructuredTurnEventStream, AgentError>;
841}
842
843#[async_trait::async_trait]
844pub trait CompletionAdapter: Send + Sync + 'static {
845    async fn completion(
846        &self,
847        request: CompletionRequest,
848        extensions: &crate::extensions::RequestExtensions,
849    ) -> Result<CompletionEventStream, AgentError>;
850
851    async fn structured_completion(
852        &self,
853        request: AdapterStructuredCompletionRequest,
854        extensions: &crate::extensions::RequestExtensions,
855    ) -> Result<ErasedStructuredCompletionEventStream, AgentError>;
856}
857
858#[async_trait::async_trait]
859pub trait UsageRecoveryAdapter: Send + Sync + 'static {
860    async fn recover_usage(
861        &self,
862        kind: OperationKind,
863        request_id: &str,
864    ) -> Result<Option<Usage>, AgentError>;
865}
866
867#[async_trait::async_trait]
868impl<T> TurnAdapter for Arc<T>
869where
870    T: TurnAdapter + ?Sized,
871{
872    async fn text_turn(
873        &self,
874        input: ModelInput,
875        turn: AdapterTextTurn,
876    ) -> Result<ErasedTextTurnEventStream, AgentError> {
877        (**self).text_turn(input, turn).await
878    }
879
880    async fn structured_turn(
881        &self,
882        input: ModelInput,
883        turn: AdapterStructuredTurn,
884    ) -> Result<ErasedStructuredTurnEventStream, AgentError> {
885        (**self).structured_turn(input, turn).await
886    }
887}
888
889#[async_trait::async_trait]
890impl<T> CompletionAdapter for Arc<T>
891where
892    T: CompletionAdapter + ?Sized,
893{
894    async fn completion(
895        &self,
896        request: CompletionRequest,
897        extensions: &crate::extensions::RequestExtensions,
898    ) -> Result<CompletionEventStream, AgentError> {
899        (**self).completion(request, extensions).await
900    }
901
902    async fn structured_completion(
903        &self,
904        request: AdapterStructuredCompletionRequest,
905        extensions: &crate::extensions::RequestExtensions,
906    ) -> Result<ErasedStructuredCompletionEventStream, AgentError> {
907        (**self).structured_completion(request, extensions).await
908    }
909}
910
911#[async_trait::async_trait]
912impl<T> UsageRecoveryAdapter for Arc<T>
913where
914    T: UsageRecoveryAdapter + ?Sized,
915{
916    async fn recover_usage(
917        &self,
918        kind: OperationKind,
919        request_id: &str,
920    ) -> Result<Option<Usage>, AgentError> {
921        (**self).recover_usage(kind, request_id).await
922    }
923}
924
/// Compile-time check: every stream alias and every event type in this module
/// is `Send + Sync`. The test body does nothing at runtime; it fails to
/// compile if any of the listed types loses either auto trait.
#[test]
fn test_stream_types_are_send_sync() {
    fn require_send_sync<X: Send + Sync>() {}

    // Stream aliases.
    require_send_sync::<TextTurnEventStream<NoTools>>();
    require_send_sync::<StructuredTurnEventStream<NoTools, ()>>();
    require_send_sync::<StructuredCompletionEventStream<()>>();
    require_send_sync::<CompletionEventStream>();
    require_send_sync::<ErasedTextTurnEventStream>();
    require_send_sync::<ErasedStructuredTurnEventStream>();
    require_send_sync::<ErasedStructuredCompletionEventStream>();

    // Event types.
    require_send_sync::<TextTurnEvent<NoTools>>();
    require_send_sync::<StructuredTurnEvent<NoTools, ()>>();
    require_send_sync::<StructuredCompletionEvent<()>>();
    require_send_sync::<ErasedTextTurnEvent>();
    require_send_sync::<ErasedStructuredTurnEvent>();
    require_send_sync::<ErasedStructuredCompletionEvent>();
    require_send_sync::<CompletionEvent>();
}
943
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::transcript::AssistantTurnView;

    /// Minimal structured-output type used to instantiate `StructuredTurnEvent`.
    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct Summary {
        answer: String,
    }

    /// Builds an empty turn in a fresh allocation; two calls never share an `Arc`.
    fn empty_turn() -> Arc<AssistantTurnView> {
        Arc::new(AssistantTurnView::from_items(&[]))
    }

    /// `TextTurnEvent::Completed` fixture that differs only in which turn it holds.
    fn text_completed(turn: &Arc<AssistantTurnView>) -> TextTurnEvent<NoTools> {
        TextTurnEvent::Completed {
            request_id: Some("req-1".into()),
            finish_reason: FinishReason::Stop,
            usage: Usage::zero(),
            committed_turn: turn.clone(),
        }
    }

    /// `StructuredTurnEvent::Completed` fixture that differs only in which turn it holds.
    fn structured_completed(
        turn: &Arc<AssistantTurnView>,
    ) -> StructuredTurnEvent<NoTools, Summary> {
        StructuredTurnEvent::Completed {
            request_id: Some("req-2".into()),
            finish_reason: FinishReason::Stop,
            usage: Usage::zero(),
            committed_turn: turn.clone(),
        }
    }

    #[test]
    fn text_completed_equality_checks_committed_turn_identity() {
        let shared_turn = empty_turn();
        let different_turn = empty_turn();

        let lhs = text_completed(&shared_turn);
        let same = text_completed(&shared_turn);
        let different = text_completed(&different_turn);

        // Same allocation => equal; equal contents in another allocation => not equal.
        assert_eq!(lhs, same);
        assert_ne!(lhs, different);
    }

    #[test]
    fn structured_completed_equality_checks_committed_turn_identity() {
        let shared_turn = empty_turn();
        let different_turn = empty_turn();

        let lhs = structured_completed(&shared_turn);
        let same = structured_completed(&shared_turn);
        let different = structured_completed(&different_turn);

        // Same allocation => equal; equal contents in another allocation => not equal.
        assert_eq!(lhs, same);
        assert_ne!(lhs, different);
    }
}