1use std::{borrow::Cow, fmt, marker::PhantomData, pin::Pin, sync::Arc};
2
3use bon::Builder;
4use futures::Stream;
5
6use crate::{
7 AgentError,
8 budget::{RequestBudget, Usage},
9 conversation::{ModelInput, RawJson, ToolMetadata},
10 structured::StructuredOutput,
11 toolset::{NoTools, ToolPolicy, Toolset},
12 transcript::CommittedTurn,
13};
14
15pub type TextTurnEventStream<T, E = AgentError> =
16 Pin<Box<dyn Stream<Item = Result<TextTurnEvent<T>, E>> + Send + Sync + 'static>>;
17pub type StructuredTurnEventStream<T, O, E = AgentError> =
18 Pin<Box<dyn Stream<Item = Result<StructuredTurnEvent<T, O>, E>> + Send + Sync + 'static>>;
19pub type StructuredCompletionEventStream<O, E = AgentError> =
20 Pin<Box<dyn Stream<Item = Result<StructuredCompletionEvent<O>, E>> + Send + Sync + 'static>>;
21pub type CompletionEventStream<E = AgentError> =
22 Pin<Box<dyn Stream<Item = Result<CompletionEvent, E>> + Send + Sync + 'static>>;
23pub type ErasedTextTurnEventStream<E = AgentError> =
24 Pin<Box<dyn Stream<Item = Result<ErasedTextTurnEvent, E>> + Send + Sync + 'static>>;
25pub type ErasedStructuredTurnEventStream<E = AgentError> =
26 Pin<Box<dyn Stream<Item = Result<ErasedStructuredTurnEvent, E>> + Send + Sync + 'static>>;
27pub type ErasedStructuredCompletionEventStream<E = AgentError> =
28 Pin<Box<dyn Stream<Item = Result<ErasedStructuredCompletionEvent, E>> + Send + Sync + 'static>>;
29
30#[derive(
31 Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, serde::Serialize, serde::Deserialize,
32)]
33#[serde(try_from = "String", into = "String")]
34pub struct ModelName(String);
35
36impl ModelName {
37 pub fn new(model: impl Into<String>) -> Result<Self, ModelNameError> {
38 let model = model.into();
39 if model.trim().is_empty() {
40 return Err(ModelNameError::Empty);
41 }
42 Ok(Self(model))
43 }
44
45 pub fn as_str(&self) -> &str {
46 &self.0
47 }
48
49 pub fn into_string(self) -> String {
50 self.0
51 }
52}
53
54impl AsRef<str> for ModelName {
55 fn as_ref(&self) -> &str {
56 self.as_str()
57 }
58}
59
60impl fmt::Display for ModelName {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.write_str(self.as_str())
63 }
64}
65
66impl From<ModelName> for String {
67 fn from(value: ModelName) -> Self {
68 value.into_string()
69 }
70}
71
72impl TryFrom<String> for ModelName {
73 type Error = ModelNameError;
74
75 fn try_from(value: String) -> Result<Self, Self::Error> {
76 Self::new(value)
77 }
78}
79
80impl TryFrom<&str> for ModelName {
81 type Error = ModelNameError;
82
83 fn try_from(value: &str) -> Result<Self, Self::Error> {
84 Self::new(value)
85 }
86}
87
88#[derive(Clone, Debug, thiserror::Error, Eq, PartialEq)]
89pub enum ModelNameError {
90 #[error("model must not be empty")]
91 Empty,
92}
93
94#[derive(Builder, Clone, Debug, Default, PartialEq)]
95#[builder(builder_type(name = GenerationParamsBuilder))]
96pub struct GenerationParams {
97 pub temperature: Option<Temperature>,
98 pub max_output_tokens: Option<u32>,
99 pub seed: Option<u64>,
100}
101
102#[derive(Builder, Clone, Debug, PartialEq)]
103#[builder(builder_type(name = TurnConfigBuilder))]
104pub struct TurnConfig<T: Toolset = NoTools> {
105 pub model: ModelName,
106 #[builder(default)]
107 pub generation: GenerationParams,
108 #[builder(default)]
109 pub tools: ToolPolicy<T>,
110 #[builder(default = RequestBudget::unlimited())]
111 pub budget: RequestBudget,
112}
113
114impl<T> TurnConfig<T>
115where
116 T: Toolset,
117{
118 pub fn new(model: ModelName) -> Self {
119 Self {
120 model,
121 generation: GenerationParams::default(),
122 tools: ToolPolicy::Disabled,
123 budget: RequestBudget::unlimited(),
124 }
125 }
126}
127
128#[derive(Builder, Clone, Debug, PartialEq)]
129#[builder(builder_type(name = TextTurnBuilder))]
130pub struct TextTurn<T: Toolset = NoTools> {
131 pub config: TurnConfig<T>,
132}
133
134impl<T> TextTurn<T>
135where
136 T: Toolset,
137{
138 pub fn new(model: ModelName) -> Self {
139 Self {
140 config: TurnConfig::new(model),
141 }
142 }
143}
144
145#[derive(Builder, Clone, Debug, PartialEq)]
146#[builder(builder_type(name = StructuredOutputSpecBuilder))]
147pub struct StructuredOutputSpec<O: StructuredOutput> {
148 #[builder(skip = PhantomData)]
149 _marker: PhantomData<fn() -> O>,
150}
151
152impl<O> Default for StructuredOutputSpec<O>
153where
154 O: StructuredOutput,
155{
156 fn default() -> Self {
157 Self {
158 _marker: PhantomData,
159 }
160 }
161}
162
163#[derive(Builder, Clone, Debug, PartialEq)]
164#[builder(builder_type(name = StructuredTurnBuilder))]
165pub struct StructuredTurn<T: Toolset, O: StructuredOutput> {
166 pub config: TurnConfig<T>,
167 #[builder(default)]
168 pub output: StructuredOutputSpec<O>,
169}
170
171impl<T, O> StructuredTurn<T, O>
172where
173 T: Toolset,
174 O: StructuredOutput,
175{
176 pub fn new(model: ModelName) -> Self {
177 Self {
178 config: TurnConfig::new(model),
179 output: StructuredOutputSpec::default(),
180 }
181 }
182}
183
184#[derive(Builder, Clone, Debug, PartialEq)]
185#[builder(builder_type(name = StructuredCompletionRequestBuilder))]
186pub struct StructuredCompletionRequest<O: StructuredOutput> {
187 pub model: ModelName,
188 pub system: Option<String>,
189 #[builder(into)]
190 pub prompt: String,
191 #[builder(default)]
192 pub generation: GenerationParams,
193 #[builder(default = RequestBudget::unlimited())]
194 pub budget: RequestBudget,
195 #[builder(default)]
196 pub output: StructuredOutputSpec<O>,
197}
198
199impl<O> StructuredCompletionRequest<O>
200where
201 O: StructuredOutput,
202{
203 pub fn new(model: ModelName, prompt: impl Into<String>) -> Self {
204 Self {
205 model,
206 system: None,
207 prompt: prompt.into(),
208 generation: GenerationParams::default(),
209 budget: RequestBudget::unlimited(),
210 output: StructuredOutputSpec::default(),
211 }
212 }
213
214 pub fn with_system(mut self, system: impl Into<String>) -> Self {
215 self.system = Some(system.into());
216 self
217 }
218}
219
220#[derive(Clone, Debug, PartialEq, Eq)]
221pub struct AdapterToolDefinition {
222 pub name: String,
223 pub description: String,
224 pub input_schema: serde_json::Value,
225}
226
/// Tool-choice mode passed to an adapter alongside its tool definitions.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AdapterToolChoice {
    /// No tool may be called.
    None,
    /// The model decides whether to call a tool.
    Auto,
    /// Some tool must be called.
    Required,
    /// The named tool is requested (payload presumably the tool name).
    Specific(String),
}
234
235#[derive(Clone, Debug, PartialEq)]
236pub struct AdapterTurnConfig {
237 pub model: ModelName,
238 pub generation: GenerationParams,
239 pub tools: Vec<AdapterToolDefinition>,
240 pub tool_choice: AdapterToolChoice,
241}
242
/// Model choice returned by a [`ModelSelector`].
///
/// Both parts are optional; the default selection chooses nothing.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ModelSelection {
    /// Preferred model name, if any.
    pub primary: Option<Cow<'static, str>>,
    /// Alternative model names, if any (presumably tried in order; confirm with callers).
    pub fallbacks: Option<Vec<Cow<'static, str>>>,
}
252
253pub trait ModelSelector: Send + Sync {
254 fn select_model(&self, extensions: &crate::extensions::RequestExtensions) -> ModelSelection;
255}
256
257#[derive(Clone)]
258pub struct AdapterTextTurn {
259 pub config: AdapterTurnConfig,
260 pub extensions: Arc<crate::extensions::RequestExtensions>,
261}
262
263#[derive(Clone, Debug, PartialEq)]
264pub struct AdapterStructuredOutputSpec {
265 pub schema_name: String,
266 pub schema: serde_json::Value,
267}
268
269#[derive(Clone)]
270pub struct AdapterStructuredTurn {
271 pub config: AdapterTurnConfig,
272 pub extensions: Arc<crate::extensions::RequestExtensions>,
273 pub output: AdapterStructuredOutputSpec,
274}
275
276#[derive(Clone, Debug, PartialEq)]
277pub struct AdapterStructuredCompletionRequest {
278 pub model: ModelName,
279 pub system: Option<String>,
280 pub prompt: String,
281 pub generation: GenerationParams,
282 pub output: AdapterStructuredOutputSpec,
283}
284
285impl fmt::Debug for AdapterTextTurn {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 f.debug_struct("AdapterTextTurn")
288 .field("config", &self.config)
289 .field("extensions", &"<opaque>")
290 .finish()
291 }
292}
293
294impl PartialEq for AdapterTextTurn {
295 fn eq(&self, other: &Self) -> bool {
296 self.config == other.config && Arc::ptr_eq(&self.extensions, &other.extensions)
297 }
298}
299
300impl fmt::Debug for AdapterStructuredTurn {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 f.debug_struct("AdapterStructuredTurn")
303 .field("config", &self.config)
304 .field("extensions", &"<opaque>")
305 .field("output", &self.output)
306 .finish()
307 }
308}
309
310impl PartialEq for AdapterStructuredTurn {
311 fn eq(&self, other: &Self) -> bool {
312 self.config == other.config
313 && Arc::ptr_eq(&self.extensions, &other.extensions)
314 && self.output == other.output
315 }
316}
317
318#[derive(Clone, Copy, Debug, PartialEq)]
319pub struct Temperature(f32);
320
321impl Temperature {
322 pub const MIN: f32 = 0.0;
323 pub const MAX: f32 = 2.0;
324
325 pub fn new(value: f32) -> Result<Self, TemperatureError> {
326 if !value.is_finite() {
327 return Err(TemperatureError::NonFinite);
328 }
329 if !(Self::MIN..=Self::MAX).contains(&value) {
330 return Err(TemperatureError::OutOfRange { value });
331 }
332 Ok(Self(value))
333 }
334
335 pub fn get(self) -> f32 {
336 self.0
337 }
338}
339
340impl TryFrom<f32> for Temperature {
341 type Error = TemperatureError;
342
343 fn try_from(value: f32) -> Result<Self, Self::Error> {
344 Self::new(value)
345 }
346}
347
348#[derive(Clone, Debug, thiserror::Error, PartialEq)]
349pub enum TemperatureError {
350 #[error("temperature must be finite")]
351 NonFinite,
352 #[error("temperature {value} must be in the range [0.0, 2.0]")]
353 OutOfRange { value: f32 },
354}
355
356#[derive(Builder, Clone, Debug, PartialEq)]
357#[builder(builder_type(name = CompletionRequestBuilder))]
358pub struct CompletionRequest {
359 pub model: ModelName,
360 #[builder(into)]
361 pub prompt: String,
362 #[builder(default)]
363 pub options: CompletionOptions,
364 #[builder(default = RequestBudget::unlimited())]
365 pub budget: RequestBudget,
366}
367
368impl CompletionRequest {
369 pub fn new(model: ModelName, prompt: impl Into<String>) -> Self {
370 Self {
371 model,
372 prompt: prompt.into(),
373 options: CompletionOptions::default(),
374 budget: RequestBudget::unlimited(),
375 }
376 }
377
378 pub fn with_options(mut self, options: CompletionOptions) -> Self {
379 self.options = options;
380 self
381 }
382
383 pub fn with_budget(mut self, budget: RequestBudget) -> Self {
384 self.budget = budget;
385 self
386 }
387}
388
389#[derive(Builder, Clone, Debug, Default, PartialEq)]
390#[builder(builder_type(name = CompletionOptionsBuilder))]
391pub struct CompletionOptions {
392 pub temperature: Option<Temperature>,
393 pub max_output_tokens: Option<u32>,
394 #[builder(default)]
395 pub stop: Vec<String>,
396}
397
398#[derive(Clone, Debug)]
399pub enum TextTurnEvent<T: Toolset> {
400 Started {
401 request_id: Option<String>,
402 model: String,
403 },
404 TextDelta {
405 delta: String,
406 },
407 ReasoningDelta {
408 delta: String,
409 },
410 RefusalDelta {
411 delta: String,
412 },
413 ToolCallChunk {
414 id: crate::conversation::ToolCallId,
415 name: crate::conversation::ToolName,
416 arguments_json_delta: String,
417 },
418 ToolCallReady(T::ToolCall),
419 Completed {
420 request_id: Option<String>,
421 finish_reason: FinishReason,
422 usage: Usage,
423 committed_turn: CommittedTurn,
424 },
425}
426
427#[derive(Clone, Debug)]
428pub enum StructuredTurnEvent<T: Toolset, O: StructuredOutput> {
429 Started {
430 request_id: Option<String>,
431 model: String,
432 },
433 StructuredOutputChunk {
434 json_delta: String,
435 },
436 StructuredOutputReady(O),
437 ReasoningDelta {
438 delta: String,
439 },
440 RefusalDelta {
441 delta: String,
442 },
443 ToolCallChunk {
444 id: crate::conversation::ToolCallId,
445 name: crate::conversation::ToolName,
446 arguments_json_delta: String,
447 },
448 ToolCallReady(T::ToolCall),
449 Completed {
450 request_id: Option<String>,
451 finish_reason: FinishReason,
452 usage: Usage,
453 committed_turn: CommittedTurn,
454 },
455}
456
457#[derive(Clone, Debug)]
458pub enum ErasedTextTurnEvent {
459 Started {
460 request_id: Option<String>,
461 model: String,
462 },
463 TextDelta {
464 delta: String,
465 },
466 ReasoningDelta {
467 delta: String,
468 },
469 RefusalDelta {
470 delta: String,
471 },
472 ToolCallChunk {
473 id: crate::conversation::ToolCallId,
474 name: crate::conversation::ToolName,
475 arguments_json_delta: String,
476 },
477 ToolCallReady(ToolMetadata),
478 Completed {
479 request_id: Option<String>,
480 finish_reason: FinishReason,
481 usage: Usage,
482 committed_turn: CommittedTurn,
483 },
484}
485
486#[derive(Clone, Debug)]
487pub enum ErasedStructuredTurnEvent {
488 Started {
489 request_id: Option<String>,
490 model: String,
491 },
492 StructuredOutputChunk {
493 json_delta: String,
494 },
495 StructuredOutputReady(RawJson),
496 ReasoningDelta {
497 delta: String,
498 },
499 RefusalDelta {
500 delta: String,
501 },
502 ToolCallChunk {
503 id: crate::conversation::ToolCallId,
504 name: crate::conversation::ToolName,
505 arguments_json_delta: String,
506 },
507 ToolCallReady(ToolMetadata),
508 Completed {
509 request_id: Option<String>,
510 finish_reason: FinishReason,
511 usage: Usage,
512 committed_turn: CommittedTurn,
513 },
514}
515
516impl<T> PartialEq for TextTurnEvent<T>
517where
518 T: Toolset,
519 T::ToolCall: PartialEq,
520{
521 fn eq(&self, other: &Self) -> bool {
522 match (self, other) {
523 (
524 Self::Started {
525 request_id: lhs_request_id,
526 model: lhs_model,
527 },
528 Self::Started {
529 request_id: rhs_request_id,
530 model: rhs_model,
531 },
532 ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
533 (Self::TextDelta { delta: lhs }, Self::TextDelta { delta: rhs }) => lhs == rhs,
534 (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
535 lhs == rhs
536 }
537 (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
538 (
539 Self::ToolCallChunk {
540 id: lhs_id,
541 name: lhs_name,
542 arguments_json_delta: lhs_delta,
543 },
544 Self::ToolCallChunk {
545 id: rhs_id,
546 name: rhs_name,
547 arguments_json_delta: rhs_delta,
548 },
549 ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
550 (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
551 (
552 Self::Completed {
553 request_id: lhs_request_id,
554 finish_reason: lhs_finish_reason,
555 usage: lhs_usage,
556 committed_turn: lhs_committed_turn,
557 },
558 Self::Completed {
559 request_id: rhs_request_id,
560 finish_reason: rhs_finish_reason,
561 usage: rhs_usage,
562 committed_turn: rhs_committed_turn,
563 },
564 ) => {
565 lhs_request_id == rhs_request_id
566 && lhs_finish_reason == rhs_finish_reason
567 && lhs_usage == rhs_usage
568 && Arc::ptr_eq(lhs_committed_turn, rhs_committed_turn)
569 }
570 _ => false,
571 }
572 }
573}
574
575impl<T, O> PartialEq for StructuredTurnEvent<T, O>
576where
577 T: Toolset,
578 T::ToolCall: PartialEq,
579 O: StructuredOutput + PartialEq,
580{
581 fn eq(&self, other: &Self) -> bool {
582 match (self, other) {
583 (
584 Self::Started {
585 request_id: lhs_request_id,
586 model: lhs_model,
587 },
588 Self::Started {
589 request_id: rhs_request_id,
590 model: rhs_model,
591 },
592 ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
593 (
594 Self::StructuredOutputChunk { json_delta: lhs },
595 Self::StructuredOutputChunk { json_delta: rhs },
596 ) => lhs == rhs,
597 (Self::StructuredOutputReady(lhs), Self::StructuredOutputReady(rhs)) => lhs == rhs,
598 (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
599 lhs == rhs
600 }
601 (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
602 (
603 Self::ToolCallChunk {
604 id: lhs_id,
605 name: lhs_name,
606 arguments_json_delta: lhs_delta,
607 },
608 Self::ToolCallChunk {
609 id: rhs_id,
610 name: rhs_name,
611 arguments_json_delta: rhs_delta,
612 },
613 ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
614 (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
615 (
616 Self::Completed {
617 request_id: lhs_request_id,
618 finish_reason: lhs_finish_reason,
619 usage: lhs_usage,
620 committed_turn: lhs_committed_turn,
621 },
622 Self::Completed {
623 request_id: rhs_request_id,
624 finish_reason: rhs_finish_reason,
625 usage: rhs_usage,
626 committed_turn: rhs_committed_turn,
627 },
628 ) => {
629 lhs_request_id == rhs_request_id
630 && lhs_finish_reason == rhs_finish_reason
631 && lhs_usage == rhs_usage
632 && Arc::ptr_eq(lhs_committed_turn, rhs_committed_turn)
633 }
634 _ => false,
635 }
636 }
637}
638
639impl PartialEq for ErasedTextTurnEvent {
640 fn eq(&self, other: &Self) -> bool {
641 match (self, other) {
642 (
643 Self::Started {
644 request_id: lhs_request_id,
645 model: lhs_model,
646 },
647 Self::Started {
648 request_id: rhs_request_id,
649 model: rhs_model,
650 },
651 ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
652 (Self::TextDelta { delta: lhs }, Self::TextDelta { delta: rhs }) => lhs == rhs,
653 (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
654 lhs == rhs
655 }
656 (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
657 (
658 Self::ToolCallChunk {
659 id: lhs_id,
660 name: lhs_name,
661 arguments_json_delta: lhs_delta,
662 },
663 Self::ToolCallChunk {
664 id: rhs_id,
665 name: rhs_name,
666 arguments_json_delta: rhs_delta,
667 },
668 ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
669 (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
670 (
671 Self::Completed {
672 request_id: lhs_request_id,
673 finish_reason: lhs_finish_reason,
674 usage: lhs_usage,
675 ..
676 },
677 Self::Completed {
678 request_id: rhs_request_id,
679 finish_reason: rhs_finish_reason,
680 usage: rhs_usage,
681 ..
682 },
683 ) => {
684 lhs_request_id == rhs_request_id
685 && lhs_finish_reason == rhs_finish_reason
686 && lhs_usage == rhs_usage
687 }
688 _ => false,
689 }
690 }
691}
692
693impl PartialEq for ErasedStructuredTurnEvent {
694 fn eq(&self, other: &Self) -> bool {
695 match (self, other) {
696 (
697 Self::Started {
698 request_id: lhs_request_id,
699 model: lhs_model,
700 },
701 Self::Started {
702 request_id: rhs_request_id,
703 model: rhs_model,
704 },
705 ) => lhs_request_id == rhs_request_id && lhs_model == rhs_model,
706 (
707 Self::StructuredOutputChunk { json_delta: lhs },
708 Self::StructuredOutputChunk { json_delta: rhs },
709 ) => lhs == rhs,
710 (Self::StructuredOutputReady(lhs), Self::StructuredOutputReady(rhs)) => lhs == rhs,
711 (Self::ReasoningDelta { delta: lhs }, Self::ReasoningDelta { delta: rhs }) => {
712 lhs == rhs
713 }
714 (Self::RefusalDelta { delta: lhs }, Self::RefusalDelta { delta: rhs }) => lhs == rhs,
715 (
716 Self::ToolCallChunk {
717 id: lhs_id,
718 name: lhs_name,
719 arguments_json_delta: lhs_delta,
720 },
721 Self::ToolCallChunk {
722 id: rhs_id,
723 name: rhs_name,
724 arguments_json_delta: rhs_delta,
725 },
726 ) => lhs_id == rhs_id && lhs_name == rhs_name && lhs_delta == rhs_delta,
727 (Self::ToolCallReady(lhs), Self::ToolCallReady(rhs)) => lhs == rhs,
728 (
729 Self::Completed {
730 request_id: lhs_request_id,
731 finish_reason: lhs_finish_reason,
732 usage: lhs_usage,
733 ..
734 },
735 Self::Completed {
736 request_id: rhs_request_id,
737 finish_reason: rhs_finish_reason,
738 usage: rhs_usage,
739 ..
740 },
741 ) => {
742 lhs_request_id == rhs_request_id
743 && lhs_finish_reason == rhs_finish_reason
744 && lhs_usage == rhs_usage
745 }
746 _ => false,
747 }
748 }
749}
750
751#[derive(Clone, Debug, Eq, PartialEq)]
752pub enum CompletionEvent {
753 Started {
754 request_id: Option<String>,
755 model: String,
756 },
757 TextDelta(String),
758 Completed {
759 request_id: Option<String>,
760 finish_reason: FinishReason,
761 usage: Usage,
762 },
763}
764
765#[derive(Clone, Debug, PartialEq)]
766pub enum StructuredCompletionEvent<O: StructuredOutput> {
767 Started {
768 request_id: Option<String>,
769 model: String,
770 },
771 StructuredOutputChunk {
772 json_delta: String,
773 },
774 StructuredOutputReady(O),
775 ReasoningDelta {
776 delta: String,
777 },
778 RefusalDelta {
779 delta: String,
780 },
781 Completed {
782 request_id: Option<String>,
783 finish_reason: FinishReason,
784 usage: Usage,
785 },
786}
787
788#[derive(Clone, Debug, Eq, PartialEq)]
789pub enum ErasedStructuredCompletionEvent {
790 Started {
791 request_id: Option<String>,
792 model: String,
793 },
794 StructuredOutputChunk {
795 json_delta: String,
796 },
797 StructuredOutputReady(RawJson),
798 ReasoningDelta {
799 delta: String,
800 },
801 RefusalDelta {
802 delta: String,
803 },
804 Completed {
805 request_id: Option<String>,
806 finish_reason: FinishReason,
807 usage: Usage,
808 },
809}
810
811#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
812pub enum FinishReason {
813 Stop,
814 Length,
815 ToolCall,
816 ContentFilter,
817 Unknown(String),
818}
819
/// Kind of model operation, e.g. as passed to `UsageRecoveryAdapter::recover_usage`.
///
/// Declaration order matters: the derived `Ord` sorts variants in the order below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OperationKind {
    /// A free-text turn.
    TextTurn,
    /// A structured turn.
    StructuredTurn,
    /// A structured one-shot completion.
    StructuredCompletion,
    /// A plain completion.
    Completion,
}
827
828#[async_trait::async_trait]
829pub trait TurnAdapter: Send + Sync + 'static {
830 async fn text_turn(
831 &self,
832 input: ModelInput,
833 turn: AdapterTextTurn,
834 ) -> Result<ErasedTextTurnEventStream, AgentError>;
835
836 async fn structured_turn(
837 &self,
838 input: ModelInput,
839 turn: AdapterStructuredTurn,
840 ) -> Result<ErasedStructuredTurnEventStream, AgentError>;
841}
842
843#[async_trait::async_trait]
844pub trait CompletionAdapter: Send + Sync + 'static {
845 async fn completion(
846 &self,
847 request: CompletionRequest,
848 extensions: &crate::extensions::RequestExtensions,
849 ) -> Result<CompletionEventStream, AgentError>;
850
851 async fn structured_completion(
852 &self,
853 request: AdapterStructuredCompletionRequest,
854 extensions: &crate::extensions::RequestExtensions,
855 ) -> Result<ErasedStructuredCompletionEventStream, AgentError>;
856}
857
858#[async_trait::async_trait]
859pub trait UsageRecoveryAdapter: Send + Sync + 'static {
860 async fn recover_usage(
861 &self,
862 kind: OperationKind,
863 request_id: &str,
864 ) -> Result<Option<Usage>, AgentError>;
865}
866
867#[async_trait::async_trait]
868impl<T> TurnAdapter for Arc<T>
869where
870 T: TurnAdapter + ?Sized,
871{
872 async fn text_turn(
873 &self,
874 input: ModelInput,
875 turn: AdapterTextTurn,
876 ) -> Result<ErasedTextTurnEventStream, AgentError> {
877 (**self).text_turn(input, turn).await
878 }
879
880 async fn structured_turn(
881 &self,
882 input: ModelInput,
883 turn: AdapterStructuredTurn,
884 ) -> Result<ErasedStructuredTurnEventStream, AgentError> {
885 (**self).structured_turn(input, turn).await
886 }
887}
888
889#[async_trait::async_trait]
890impl<T> CompletionAdapter for Arc<T>
891where
892 T: CompletionAdapter + ?Sized,
893{
894 async fn completion(
895 &self,
896 request: CompletionRequest,
897 extensions: &crate::extensions::RequestExtensions,
898 ) -> Result<CompletionEventStream, AgentError> {
899 (**self).completion(request, extensions).await
900 }
901
902 async fn structured_completion(
903 &self,
904 request: AdapterStructuredCompletionRequest,
905 extensions: &crate::extensions::RequestExtensions,
906 ) -> Result<ErasedStructuredCompletionEventStream, AgentError> {
907 (**self).structured_completion(request, extensions).await
908 }
909}
910
911#[async_trait::async_trait]
912impl<T> UsageRecoveryAdapter for Arc<T>
913where
914 T: UsageRecoveryAdapter + ?Sized,
915{
916 async fn recover_usage(
917 &self,
918 kind: OperationKind,
919 request_id: &str,
920 ) -> Result<Option<Usage>, AgentError> {
921 (**self).recover_usage(kind, request_id).await
922 }
923}
924
/// Compile-time guard: every public stream alias and event type must be `Send + Sync`.
#[test]
fn test_stream_types_are_send_sync() {
    fn require_send_sync<T: Send + Sync>() {}

    // Boxed stream aliases.
    require_send_sync::<TextTurnEventStream<NoTools>>();
    require_send_sync::<StructuredTurnEventStream<NoTools, ()>>();
    require_send_sync::<StructuredCompletionEventStream<()>>();
    require_send_sync::<CompletionEventStream>();
    require_send_sync::<ErasedTextTurnEventStream>();
    require_send_sync::<ErasedStructuredTurnEventStream>();
    require_send_sync::<ErasedStructuredCompletionEventStream>();

    // Event payloads carried by those streams.
    require_send_sync::<TextTurnEvent<NoTools>>();
    require_send_sync::<StructuredTurnEvent<NoTools, ()>>();
    require_send_sync::<StructuredCompletionEvent<()>>();
    require_send_sync::<ErasedTextTurnEvent>();
    require_send_sync::<ErasedStructuredTurnEvent>();
    require_send_sync::<ErasedStructuredCompletionEvent>();
    require_send_sync::<CompletionEvent>();
}
943
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::transcript::AssistantTurnView;

    /// Minimal structured-output payload used to instantiate `StructuredTurnEvent`.
    #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
    struct Summary {
        answer: String,
    }

    /// Builds a fresh, empty committed turn; every call yields a distinct `Arc`.
    fn empty_turn() -> CommittedTurn {
        Arc::new(AssistantTurnView::from_items(&[]))
    }

    #[test]
    fn text_completed_equality_checks_committed_turn_identity() {
        let completed = |committed_turn: CommittedTurn| TextTurnEvent::<NoTools>::Completed {
            request_id: Some("req-1".into()),
            finish_reason: FinishReason::Stop,
            usage: Usage::zero(),
            committed_turn,
        };

        let shared = empty_turn();
        let lhs = completed(Arc::clone(&shared));
        let same = completed(shared);
        let different = completed(empty_turn());

        // Same allocation => equal; a structurally identical but distinct turn => not equal.
        assert_eq!(lhs, same);
        assert_ne!(lhs, different);
    }

    #[test]
    fn structured_completed_equality_checks_committed_turn_identity() {
        let completed =
            |committed_turn: CommittedTurn| StructuredTurnEvent::<NoTools, Summary>::Completed {
                request_id: Some("req-2".into()),
                finish_reason: FinishReason::Stop,
                usage: Usage::zero(),
                committed_turn,
            };

        let shared = empty_turn();
        let lhs = completed(Arc::clone(&shared));
        let same = completed(shared);
        let different = completed(empty_turn());

        // Same allocation => equal; a structurally identical but distinct turn => not equal.
        assert_eq!(lhs, same);
        assert_ne!(lhs, different);
    }
}