1use std::collections::HashMap;
2
3use schemars::JsonSchema;
4use serde::de::DeserializeOwned;
5use tools_rs::ToolCollection;
6
7#[cfg(feature = "stream")]
8use crate::chat::state::Streamed;
9use crate::{
10 chat::{
11 Chat,
12 state::{Embedded, Structured, Unstructured},
13 },
14 traits::CompletionProvider,
15 types::{
16 callback::{CallbackStrategy, RetryStrategy},
17 options::ChatOptions,
18 tools::{ScopedCollection, TypedCollection},
19 },
20};
21
22#[cfg(feature = "stream")]
23use crate::traits::StreamProvider;
24
25pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
26 model: Option<CP>,
27 output_shape: Option<schemars::Schema>,
28 model_options: Option<ChatOptions>,
29 max_steps: Option<u16>,
30 max_retries: Option<u16>,
31 retry_strategy: Option<RetryStrategy>,
32 before_strategy: Option<CallbackStrategy>,
33 after_strategy: Option<CallbackStrategy>,
34 scoped_collections: Vec<Box<dyn TypedCollection>>,
35 _output: std::marker::PhantomData<Output>,
36}
37
38impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
39 pub fn new() -> Self {
40 ChatBuilder {
41 _output: std::marker::PhantomData,
42 ..Default::default()
43 }
44 }
45
46 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
47 where
48 T: JsonSchema + DeserializeOwned,
49 {
50 let shape = schemars::schema_for!(T);
51
52 ChatBuilder {
53 model: self.model,
54 max_steps: self.max_steps,
55 max_retries: self.max_retries,
56 retry_strategy: self.retry_strategy,
57 before_strategy: self.before_strategy,
58 after_strategy: self.after_strategy,
59 output_shape: Some(shape),
60 scoped_collections: self.scoped_collections,
61 model_options: self.model_options,
62 _output: std::marker::PhantomData,
63 }
64 }
65
66 #[cfg(feature = "stream")]
67 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
68 where
69 CP: StreamProvider,
70 {
71 if self.output_shape.is_some() {
72 println!(
73 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
74 );
75 }
76
77 ChatBuilder {
78 model: self.model,
79 max_steps: self.max_steps,
80 max_retries: self.max_retries,
81 retry_strategy: self.retry_strategy,
82 before_strategy: self.before_strategy,
83 after_strategy: self.after_strategy,
84 output_shape: None,
85 scoped_collections: self.scoped_collections,
86 model_options: self.model_options,
87 _output: std::marker::PhantomData,
88 }
89 }
90
91 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
92 if self.output_shape.is_some() {
93 println!(
94 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
95 );
96 }
97
98 ChatBuilder {
99 model: self.model,
100 max_retries: self.max_retries,
101 retry_strategy: self.retry_strategy,
102 before_strategy: self.before_strategy,
103 after_strategy: self.after_strategy,
104 output_shape: None,
105 scoped_collections: Vec::new(),
106 max_steps: None,
107 model_options: self.model_options,
108 _output: std::marker::PhantomData,
109 }
110 }
111}
112
113impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
114 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
115 self.max_steps = Some(max_steps);
116 self
117 }
118
119 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
120 self.max_retries = Some(max_retries);
121 self
122 }
123
124 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
128 self.scoped_collections
129 .push(Box::new(ScopedCollection::auto_execute(tools)));
130 self
131 }
132
133 pub fn with_scoped_tools<M, F>(mut self, scoped: ScopedCollection<M, F>) -> Self
137 where
138 M: Send + Sync + 'static,
139 F: Fn(&tools_rs::FunctionCall, &M) -> crate::types::tools::Action
140 + Send
141 + Sync
142 + 'static,
143 {
144 self.scoped_collections.push(Box::new(scoped));
145 self
146 }
147
148 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
149 self.retry_strategy = Some(retry_strategy);
150 self
151 }
152
153 pub fn with_model(mut self, model: CP) -> Self {
154 self.model = Some(model);
155 self
156 }
157
158 pub fn with_options(mut self, options: ChatOptions) -> Self {
159 self.model_options = Some(options);
160 self
161 }
162
163 pub fn build(self) -> Chat<CP, Output> {
164 let mut routing: HashMap<String, usize> = HashMap::new();
169 for (idx, coll) in self.scoped_collections.iter().enumerate() {
170 for name in coll.names() {
171 if routing.contains_key(name) {
172 eprintln!(
173 "chat-rs: tool name `{name}` is registered in multiple scoped \
174 collections; keeping the first registration."
175 );
176 continue;
177 }
178 routing.insert(name.to_string(), idx);
179 }
180 }
181
182 Chat {
183 model: self.model.expect("Need to set a model"),
184 output_shape: self.output_shape,
185 max_steps: self.max_steps,
186 max_retries: self.max_retries,
187 retry_strategy: self.retry_strategy,
188 before_strategy: self.before_strategy,
189 after_strategy: self.after_strategy,
190 scoped_collections: self.scoped_collections,
191 routing,
192 model_options: self.model_options,
193 _output: std::marker::PhantomData,
194 }
195 }
196}
197
198impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
199 fn default() -> Self {
200 ChatBuilder {
201 model: None,
202 output_shape: None,
203 model_options: None,
204 max_steps: None,
205 max_retries: None,
206 retry_strategy: None,
207 before_strategy: None,
208 after_strategy: None,
209 scoped_collections: Vec::new(),
210 _output: std::marker::PhantomData,
211 }
212 }
213}