// chat_core/builder.rs
1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::{InputStreamed, Streamed};
9use crate::{
10    chat::{
11        Chat,
12        state::{Embedded, Structured, Unstructured},
13    },
14    traits::CompletionProvider,
15    types::{
16        callback::{CallbackStrategy, RetryStrategy},
17        options::ChatOptions,
18        tools::{ScopedCollection, TypedCollection},
19    },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26    model: Option<CP>,
27    output_shape: Option<schemars::Schema>,
28    model_options: Option<ChatOptions>,
29    max_steps: Option<u16>,
30    max_retries: Option<u16>,
31    retry_strategy: Option<RetryStrategy>,
32    before_strategy: Option<CallbackStrategy>,
33    after_strategy: Option<CallbackStrategy>,
34    scoped_collections: Vec<Box<dyn TypedCollection>>,
35    _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39    pub fn new() -> Self {
40        ChatBuilder {
41            _output: std::marker::PhantomData,
42            ..Default::default()
43        }
44    }
45
46    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47    where
48        T: JsonSchema + DeserializeOwned,
49    {
50        let shape = schemars::schema_for!(T);
51
52        ChatBuilder {
53            model: self.model,
54            max_steps: self.max_steps,
55            max_retries: self.max_retries,
56            retry_strategy: self.retry_strategy,
57            before_strategy: self.before_strategy,
58            after_strategy: self.after_strategy,
59            output_shape: Some(shape),
60            scoped_collections: self.scoped_collections,
61            model_options: self.model_options,
62            _output: std::marker::PhantomData,
63        }
64    }
65
66    #[cfg(feature = "stream")]
67    pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68    where
69        CP: StreamProvider,
70    {
71        if self.output_shape.is_some() {
72            println!(
73                "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74            );
75        }
76
77        ChatBuilder {
78            model: self.model,
79            max_steps: self.max_steps,
80            max_retries: self.max_retries,
81            retry_strategy: self.retry_strategy,
82            before_strategy: self.before_strategy,
83            after_strategy: self.after_strategy,
84            output_shape: None,
85            scoped_collections: self.scoped_collections,
86            model_options: self.model_options,
87            _output: std::marker::PhantomData,
88        }
89    }
90
91    /// Transition into the input-stream type-state. The resulting
92    /// `Chat<CP, InputStreamed<I>>` exposes a `stream(&mut messages, input)`
93    /// method that interleaves the model's output stream with a
94    /// caller-supplied `Stream<Item = PartEnum>` input source. Audio
95    /// chunks ride as `PartEnum::File`, text as `PartEnum::Text`,
96    /// tool results as `PartEnum::Tool`, etc. — no parallel input enum.
97    ///
98    /// `I` is just a type-marker at builder time; the actual stream
99    /// instance is passed at `chat.stream(...)` call time, so a single
100    /// `Chat` can be reused across multiple input streams of the same
101    /// type.
102    #[cfg(feature = "stream")]
103    pub fn with_input_stream<I>(self) -> ChatBuilder<CP, InputStreamed<I>>
104    where
105        CP: StreamProvider,
106        I: futures::Stream<Item = crate::types::messages::parts::PartEnum>
107            + Send
108            + Unpin
109            + 'static,
110    {
111        if self.output_shape.is_some() {
112            println!(
113                "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
114            );
115        }
116
117        ChatBuilder {
118            model: self.model,
119            max_steps: self.max_steps,
120            max_retries: self.max_retries,
121            retry_strategy: self.retry_strategy,
122            before_strategy: self.before_strategy,
123            after_strategy: self.after_strategy,
124            output_shape: None,
125            scoped_collections: self.scoped_collections,
126            model_options: self.model_options,
127            _output: std::marker::PhantomData,
128        }
129    }
130
131    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
132        if self.output_shape.is_some() {
133            println!(
134                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
135            );
136        }
137
138        ChatBuilder {
139            model: self.model,
140            max_retries: self.max_retries,
141            retry_strategy: self.retry_strategy,
142            before_strategy: self.before_strategy,
143            after_strategy: self.after_strategy,
144            output_shape: None,
145            scoped_collections: Vec::new(),
146            max_steps: None,
147            model_options: self.model_options,
148            _output: std::marker::PhantomData,
149        }
150    }
151}
152
153impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
154    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
155        self.max_steps = Some(max_steps);
156        self
157    }
158
159    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
160        self.max_retries = Some(max_retries);
161        self
162    }
163
164    /// Convenience: wrap a plain `ToolCollection<NoMeta>` with an
165    /// always-execute strategy. Equivalent to
166    /// `.with_scoped_tools(ScopedCollection::auto_execute(tools))`.
167    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
168        self.scoped_collections
169            .push(Box::new(ScopedCollection::auto_execute(tools)));
170        self
171    }
172
173    /// Attach a typed tool collection with a user-defined strategy.
174    /// Multiple calls are additive — the resulting `Chat` routes each
175    /// tool call to the collection that owns its name.
176    pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
177    where
178        M: Send + Sync + 'static,
179        F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
180    {
181        self.scoped_collections.push(Box::new(scoped));
182        self
183    }
184
185    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
186        self.retry_strategy = Some(retry_strategy);
187        self
188    }
189
190    pub fn with_model(mut self, model: CP) -> Self {
191        self.model = Some(model);
192        self
193    }
194
195    pub fn with_options(mut self, options: ChatOptions) -> Self {
196        self.model_options = Some(options);
197        self
198    }
199
200    pub fn build(self) -> Chat<CP, Output> {
201        // Build routing table: tool name → index into scoped_collections.
202        // Collisions across collections are a programming error — we
203        // keep the first and ignore later ones with a warning. (Could
204        // be promoted to a hard error via a builder method later.)
205        let mut routing: HashMap<String, usize> = HashMap::new();
206        for (idx, coll) in self.scoped_collections.iter().enumerate() {
207            for name in coll.names() {
208                if routing.contains_key(name) {
209                    eprintln!(
210                        "chat-rs: tool name `{name}` is registered in multiple scoped \
211                         collections; keeping the first registration."
212                    );
213                    continue;
214                }
215                routing.insert(name.to_string(), idx);
216            }
217        }
218
219        Chat {
220            model: self.model.expect("Need to set a model"),
221            output_shape: self.output_shape,
222            max_steps: self.max_steps,
223            max_retries: self.max_retries,
224            retry_strategy: self.retry_strategy,
225            before_strategy: self.before_strategy,
226            after_strategy: self.after_strategy,
227            scoped_collections: self.scoped_collections,
228            routing,
229            model_options: self.model_options,
230            _output: std::marker::PhantomData,
231        }
232    }
233}
234
235impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
236    fn default() -> Self {
237        ChatBuilder {
238            model: None,
239            output_shape: None,
240            model_options: None,
241            max_steps: None,
242            max_retries: None,
243            retry_strategy: None,
244            before_strategy: None,
245            after_strategy: None,
246            scoped_collections: Vec::new(),
247            _output: std::marker::PhantomData,
248        }
249    }
250}