Skip to main content

chat_core/
builder.rs

1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::InputStreamed;
9use crate::{
10    chat::{
11        Chat,
12        state::{Embedded, Structured, Unstructured},
13    },
14    traits::CompletionProvider,
15    types::{
16        callback::{CallbackStrategy, RetryStrategy},
17        options::ChatOptions,
18        tools::{ScopedCollection, TypedCollection},
19    },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26    model: Option<CP>,
27    output_shape: Option<schemars::Schema>,
28    model_options: Option<ChatOptions>,
29    max_steps: Option<u16>,
30    max_retries: Option<u16>,
31    retry_strategy: Option<RetryStrategy>,
32    before_strategy: Option<CallbackStrategy>,
33    after_strategy: Option<CallbackStrategy>,
34    scoped_collections: Vec<Box<dyn TypedCollection>>,
35    _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39    pub fn new() -> Self {
40        ChatBuilder {
41            _output: std::marker::PhantomData,
42            ..Default::default()
43        }
44    }
45
46    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47    where
48        T: JsonSchema + DeserializeOwned,
49    {
50        let shape = schemars::schema_for!(T);
51
52        ChatBuilder {
53            model: self.model,
54            max_steps: self.max_steps,
55            max_retries: self.max_retries,
56            retry_strategy: self.retry_strategy,
57            before_strategy: self.before_strategy,
58            after_strategy: self.after_strategy,
59            output_shape: Some(shape),
60            scoped_collections: self.scoped_collections,
61            model_options: self.model_options,
62            _output: std::marker::PhantomData,
63        }
64    }
65
66    /// Transition into the input-stream type-state. The resulting
67    /// `Chat<CP, InputStreamed>` exposes a `stream(&mut messages)` method
68    /// that returns a [`ChatStream`](crate::chat::input::ChatStream): the
69    /// output stream you iterate with `.next()`, carrying an input side you
70    /// push to with `.send()` (or `split()` into independent handles).
71    /// Audio rides as `PartEnum::File`, text as `PartEnum::Text`, tool
72    /// results as `PartEnum::Tool`, etc. — no parallel input enum, and a
73    /// continuous producer is mapped into `PartEnum` caller-side before it
74    /// reaches `send`.
75    #[cfg(feature = "stream")]
76    pub fn with_input_stream(self) -> ChatBuilder<CP, InputStreamed>
77    where
78        CP: StreamProvider,
79    {
80        if self.output_shape.is_some() {
81            println!(
82                "Warning: Cannot call input-streamed responses with structured outputs. Output shape will be set to None"
83            );
84        }
85
86        ChatBuilder {
87            model: self.model,
88            max_steps: self.max_steps,
89            max_retries: self.max_retries,
90            retry_strategy: self.retry_strategy,
91            before_strategy: self.before_strategy,
92            after_strategy: self.after_strategy,
93            output_shape: None,
94            scoped_collections: self.scoped_collections,
95            model_options: self.model_options,
96            _output: std::marker::PhantomData,
97        }
98    }
99
100    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
101        if self.output_shape.is_some() {
102            println!(
103                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
104            );
105        }
106
107        ChatBuilder {
108            model: self.model,
109            max_retries: self.max_retries,
110            retry_strategy: self.retry_strategy,
111            before_strategy: self.before_strategy,
112            after_strategy: self.after_strategy,
113            output_shape: None,
114            scoped_collections: Vec::new(),
115            max_steps: None,
116            model_options: self.model_options,
117            _output: std::marker::PhantomData,
118        }
119    }
120}
121
122impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
123    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
124        self.max_steps = Some(max_steps);
125        self
126    }
127
128    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
129        self.max_retries = Some(max_retries);
130        self
131    }
132
133    /// Convenience: wrap a plain `ToolCollection<NoMeta>` with an
134    /// always-execute strategy. Equivalent to
135    /// `.with_scoped_tools(ScopedCollection::auto_execute(tools))`.
136    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
137        self.scoped_collections
138            .push(Box::new(ScopedCollection::auto_execute(tools)));
139        self
140    }
141
142    /// Attach a typed tool collection with a user-defined strategy.
143    /// Multiple calls are additive — the resulting `Chat` routes each
144    /// tool call to the collection that owns its name.
145    pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
146    where
147        M: Send + Sync + 'static,
148        F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action + Send + Sync + 'static,
149    {
150        self.scoped_collections.push(Box::new(scoped));
151        self
152    }
153
154    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
155        self.retry_strategy = Some(retry_strategy);
156        self
157    }
158
159    /// Run a strategy before the loop starts. It receives the outgoing
160    /// `Messages` (which it may mutate synchronously) and the most recent
161    /// `Metadata`, and returns a future for any async side effects.
162    pub fn with_before_strategy(mut self, before_strategy: CallbackStrategy) -> Self {
163        self.before_strategy = Some(before_strategy);
164        self
165    }
166
167    /// Run a strategy after the loop completes successfully. It receives the
168    /// final `Messages` and the run's `Metadata`.
169    pub fn with_after_strategy(mut self, after_strategy: CallbackStrategy) -> Self {
170        self.after_strategy = Some(after_strategy);
171        self
172    }
173
174    pub fn with_model(mut self, model: CP) -> Self {
175        self.model = Some(model);
176        self
177    }
178
179    pub fn with_options(mut self, options: ChatOptions) -> Self {
180        self.model_options = Some(options);
181        self
182    }
183
184    pub fn build(self) -> Chat<CP, Output> {
185        let mut routing: HashMap<String, usize> = HashMap::new();
186        for (idx, coll) in self.scoped_collections.iter().enumerate() {
187            for name in coll.names() {
188                if routing.contains_key(name) {
189                    eprintln!(
190                        "chat-rs: tool name `{name}` is registered in multiple scoped \
191                         collections; keeping the first registration."
192                    );
193                    continue;
194                }
195                routing.insert(name.to_string(), idx);
196            }
197        }
198
199        Chat {
200            model: self.model.expect("Need to set a model"),
201            output_shape: self.output_shape,
202            max_steps: self.max_steps,
203            max_retries: self.max_retries,
204            retry_strategy: self.retry_strategy,
205            before_strategy: self.before_strategy,
206            after_strategy: self.after_strategy,
207            scoped_collections: self.scoped_collections,
208            routing,
209            model_options: self.model_options,
210            _output: std::marker::PhantomData,
211        }
212    }
213}
214
215impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
216    fn default() -> Self {
217        ChatBuilder {
218            model: None,
219            output_shape: None,
220            model_options: None,
221            max_steps: None,
222            max_retries: None,
223            retry_strategy: None,
224            before_strategy: None,
225            after_strategy: None,
226            scoped_collections: Vec::new(),
227            _output: std::marker::PhantomData,
228        }
229    }
230}