Skip to main content

chat_core/
builder.rs

1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::Streamed;
9use crate::{
10    chat::{
11        Chat,
12        state::{Embedded, Structured, Unstructured},
13    },
14    traits::CompletionProvider,
15    types::{
16        callback::{CallbackStrategy, RetryStrategy},
17        options::ChatOptions,
18        tools::{ScopedCollection, TypedCollection},
19    },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26    model: Option<CP>,
27    output_shape: Option<schemars::Schema>,
28    model_options: Option<ChatOptions>,
29    max_steps: Option<u16>,
30    max_retries: Option<u16>,
31    retry_strategy: Option<RetryStrategy>,
32    before_strategy: Option<CallbackStrategy>,
33    after_strategy: Option<CallbackStrategy>,
34    scoped_collections: Vec<Box<dyn TypedCollection>>,
35    _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39    pub fn new() -> Self {
40        ChatBuilder {
41            _output: std::marker::PhantomData,
42            ..Default::default()
43        }
44    }
45
46    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47    where
48        T: JsonSchema + DeserializeOwned,
49    {
50        let shape = schemars::schema_for!(T);
51
52        ChatBuilder {
53            model: self.model,
54            max_steps: self.max_steps,
55            max_retries: self.max_retries,
56            retry_strategy: self.retry_strategy,
57            before_strategy: self.before_strategy,
58            after_strategy: self.after_strategy,
59            output_shape: Some(shape),
60            scoped_collections: self.scoped_collections,
61            model_options: self.model_options,
62            _output: std::marker::PhantomData,
63        }
64    }
65
66    #[cfg(feature = "stream")]
67    pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68    where
69        CP: StreamProvider,
70    {
71        if self.output_shape.is_some() {
72            println!(
73                "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74            );
75        }
76
77        ChatBuilder {
78            model: self.model,
79            max_steps: self.max_steps,
80            max_retries: self.max_retries,
81            retry_strategy: self.retry_strategy,
82            before_strategy: self.before_strategy,
83            after_strategy: self.after_strategy,
84            output_shape: None,
85            scoped_collections: self.scoped_collections,
86            model_options: self.model_options,
87            _output: std::marker::PhantomData,
88        }
89    }
90
91    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
92        if self.output_shape.is_some() {
93            println!(
94                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
95            );
96        }
97
98        ChatBuilder {
99            model: self.model,
100            max_retries: self.max_retries,
101            retry_strategy: self.retry_strategy,
102            before_strategy: self.before_strategy,
103            after_strategy: self.after_strategy,
104            output_shape: None,
105            scoped_collections: Vec::new(),
106            max_steps: None,
107            model_options: self.model_options,
108            _output: std::marker::PhantomData,
109        }
110    }
111}
112
113impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
114    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
115        self.max_steps = Some(max_steps);
116        self
117    }
118
119    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
120        self.max_retries = Some(max_retries);
121        self
122    }
123
124    /// Convenience: wrap a plain `ToolCollection<NoMeta>` with an
125    /// always-execute strategy. Equivalent to
126    /// `.with_scoped_tools(ScopedCollection::auto_execute(tools))`.
127    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
128        self.scoped_collections
129            .push(Box::new(ScopedCollection::auto_execute(tools)));
130        self
131    }
132
133    /// Attach a typed tool collection with a user-defined strategy.
134    /// Multiple calls are additive — the resulting `Chat` routes each
135    /// tool call to the collection that owns its name.
136    pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
137    where
138        M: Send + Sync + 'static,
139        F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action
140            + Send
141            + Sync
142            + 'static,
143    {
144        self.scoped_collections.push(Box::new(scoped));
145        self
146    }
147
148    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
149        self.retry_strategy = Some(retry_strategy);
150        self
151    }
152
153    pub fn with_model(mut self, model: CP) -> Self {
154        self.model = Some(model);
155        self
156    }
157
158    pub fn with_options(mut self, options: ChatOptions) -> Self {
159        self.model_options = Some(options);
160        self
161    }
162
163    pub fn build(self) -> Chat<CP, Output> {
164        // Build routing table: tool name → index into scoped_collections.
165        // Collisions across collections are a programming error — we
166        // keep the first and ignore later ones with a warning. (Could
167        // be promoted to a hard error via a builder method later.)
168        let mut routing: HashMap<String, usize> = HashMap::new();
169        for (idx, coll) in self.scoped_collections.iter().enumerate() {
170            for name in coll.names() {
171                if routing.contains_key(name) {
172                    eprintln!(
173                        "chat-rs: tool name `{name}` is registered in multiple scoped \
174                         collections; keeping the first registration."
175                    );
176                    continue;
177                }
178                routing.insert(name.to_string(), idx);
179            }
180        }
181
182        Chat {
183            model: self.model.expect("Need to set a model"),
184            output_shape: self.output_shape,
185            max_steps: self.max_steps,
186            max_retries: self.max_retries,
187            retry_strategy: self.retry_strategy,
188            before_strategy: self.before_strategy,
189            after_strategy: self.after_strategy,
190            scoped_collections: self.scoped_collections,
191            routing,
192            model_options: self.model_options,
193            _output: std::marker::PhantomData,
194        }
195    }
196}
197
198impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
199    fn default() -> Self {
200        ChatBuilder {
201            model: None,
202            output_shape: None,
203            model_options: None,
204            max_steps: None,
205            max_retries: None,
206            retry_strategy: None,
207            before_strategy: None,
208            after_strategy: None,
209            scoped_collections: Vec::new(),
210            _output: std::marker::PhantomData,
211        }
212    }
213}