Skip to main content

chat_core/
builder.rs

1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::{InputStreamed, Streamed};
9use crate::{
10    chat::{
11        Chat,
12        state::{Embedded, Structured, Unstructured},
13    },
14    traits::CompletionProvider,
15    types::{
16        callback::{CallbackStrategy, RetryStrategy},
17        options::ChatOptions,
18        tools::{ScopedCollection, TypedCollection},
19    },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26    model: Option<CP>,
27    output_shape: Option<schemars::Schema>,
28    model_options: Option<ChatOptions>,
29    max_steps: Option<u16>,
30    max_retries: Option<u16>,
31    retry_strategy: Option<RetryStrategy>,
32    before_strategy: Option<CallbackStrategy>,
33    after_strategy: Option<CallbackStrategy>,
34    scoped_collections: Vec<Box<dyn TypedCollection>>,
35    _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39    pub fn new() -> Self {
40        ChatBuilder {
41            _output: std::marker::PhantomData,
42            ..Default::default()
43        }
44    }
45
46    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47    where
48        T: JsonSchema + DeserializeOwned,
49    {
50        let shape = schemars::schema_for!(T);
51
52        ChatBuilder {
53            model: self.model,
54            max_steps: self.max_steps,
55            max_retries: self.max_retries,
56            retry_strategy: self.retry_strategy,
57            before_strategy: self.before_strategy,
58            after_strategy: self.after_strategy,
59            output_shape: Some(shape),
60            scoped_collections: self.scoped_collections,
61            model_options: self.model_options,
62            _output: std::marker::PhantomData,
63        }
64    }
65
66    #[cfg(feature = "stream")]
67    pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68    where
69        CP: StreamProvider,
70    {
71        if self.output_shape.is_some() {
72            println!(
73                "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74            );
75        }
76
77        ChatBuilder {
78            model: self.model,
79            max_steps: self.max_steps,
80            max_retries: self.max_retries,
81            retry_strategy: self.retry_strategy,
82            before_strategy: self.before_strategy,
83            after_strategy: self.after_strategy,
84            output_shape: None,
85            scoped_collections: self.scoped_collections,
86            model_options: self.model_options,
87            _output: std::marker::PhantomData,
88        }
89    }
90
91    /// Transition into the input-stream type-state. The resulting
92    /// `Chat<CP, InputStreamed>` exposes a `stream(&mut messages)` method
93    /// that returns a [`ChatStream`](crate::chat::input::ChatStream): the
94    /// output stream you iterate with `.next()`, carrying an input side you
95    /// push to with `.send()` (or `split()` into independent handles).
96    /// Audio rides as `PartEnum::File`, text as `PartEnum::Text`, tool
97    /// results as `PartEnum::Tool`, etc. — no parallel input enum, and a
98    /// continuous producer is mapped into `PartEnum` caller-side before it
99    /// reaches `send`.
100    #[cfg(feature = "stream")]
101    pub fn with_input_stream(self) -> ChatBuilder<CP, InputStreamed>
102    where
103        CP: StreamProvider,
104    {
105        if self.output_shape.is_some() {
106            println!(
107                "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
108            );
109        }
110
111        ChatBuilder {
112            model: self.model,
113            max_steps: self.max_steps,
114            max_retries: self.max_retries,
115            retry_strategy: self.retry_strategy,
116            before_strategy: self.before_strategy,
117            after_strategy: self.after_strategy,
118            output_shape: None,
119            scoped_collections: self.scoped_collections,
120            model_options: self.model_options,
121            _output: std::marker::PhantomData,
122        }
123    }
124
125    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
126        if self.output_shape.is_some() {
127            println!(
128                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
129            );
130        }
131
132        ChatBuilder {
133            model: self.model,
134            max_retries: self.max_retries,
135            retry_strategy: self.retry_strategy,
136            before_strategy: self.before_strategy,
137            after_strategy: self.after_strategy,
138            output_shape: None,
139            scoped_collections: Vec::new(),
140            max_steps: None,
141            model_options: self.model_options,
142            _output: std::marker::PhantomData,
143        }
144    }
145}
146
147impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
148    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
149        self.max_steps = Some(max_steps);
150        self
151    }
152
153    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
154        self.max_retries = Some(max_retries);
155        self
156    }
157
158    /// Convenience: wrap a plain `ToolCollection<NoMeta>` with an
159    /// always-execute strategy. Equivalent to
160    /// `.with_scoped_tools(ScopedCollection::auto_execute(tools))`.
161    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
162        self.scoped_collections
163            .push(Box::new(ScopedCollection::auto_execute(tools)));
164        self
165    }
166
167    /// Attach a typed tool collection with a user-defined strategy.
168    /// Multiple calls are additive — the resulting `Chat` routes each
169    /// tool call to the collection that owns its name.
170    pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
171    where
172        M: Send + Sync + 'static,
173        F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
174    {
175        self.scoped_collections.push(Box::new(scoped));
176        self
177    }
178
179    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
180        self.retry_strategy = Some(retry_strategy);
181        self
182    }
183
184    pub fn with_model(mut self, model: CP) -> Self {
185        self.model = Some(model);
186        self
187    }
188
189    pub fn with_options(mut self, options: ChatOptions) -> Self {
190        self.model_options = Some(options);
191        self
192    }
193
194    pub fn build(self) -> Chat<CP, Output> {
195        // Build routing table: tool name → index into scoped_collections.
196        // Collisions across collections are a programming error — we
197        // keep the first and ignore later ones with a warning. (Could
198        // be promoted to a hard error via a builder method later.)
199        let mut routing: HashMap<String, usize> = HashMap::new();
200        for (idx, coll) in self.scoped_collections.iter().enumerate() {
201            for name in coll.names() {
202                if routing.contains_key(name) {
203                    eprintln!(
204                        "chat-rs: tool name `{name}` is registered in multiple scoped \
205                         collections; keeping the first registration."
206                    );
207                    continue;
208                }
209                routing.insert(name.to_string(), idx);
210            }
211        }
212
213        Chat {
214            model: self.model.expect("Need to set a model"),
215            output_shape: self.output_shape,
216            max_steps: self.max_steps,
217            max_retries: self.max_retries,
218            retry_strategy: self.retry_strategy,
219            before_strategy: self.before_strategy,
220            after_strategy: self.after_strategy,
221            scoped_collections: self.scoped_collections,
222            routing,
223            model_options: self.model_options,
224            _output: std::marker::PhantomData,
225        }
226    }
227}
228
229impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
230    fn default() -> Self {
231        ChatBuilder {
232            model: None,
233            output_shape: None,
234            model_options: None,
235            max_steps: None,
236            max_retries: None,
237            retry_strategy: None,
238            before_strategy: None,
239            after_strategy: None,
240            scoped_collections: Vec::new(),
241            _output: std::marker::PhantomData,
242        }
243    }
244}