Skip to main content

chat_core/
builder.rs

1use schemars::JsonSchema;
2use serde::de::DeserializeOwned;
3use tools_rs::ToolCollection;
4
5#[cfg(feature = "stream")]
6use crate::chat::state::Streamed;
7use crate::{
8    chat::{
9        Chat,
10        state::{Embedded, Structured, Unstructured},
11    },
12    traits::CompletionProvider,
13    types::{
14        callback::{CallbackStrategy, RetryStrategy},
15        options::ChatOptions,
16    },
17};
18
19#[cfg(feature = "stream")]
20use crate::traits::StreamProvider;
21
22pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
23    model: Option<CP>,
24    output_shape: Option<schemars::Schema>,
25    model_options: Option<ChatOptions>,
26    max_steps: Option<u16>,
27    max_retries: Option<u16>,
28    retry_strategy: Option<RetryStrategy>,
29    before_strategy: Option<CallbackStrategy>,
30    after_strategy: Option<CallbackStrategy>,
31    tools: Option<ToolCollection>,
32    _output: std::marker::PhantomData<Output>,
33}
34
35impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
36    pub fn new() -> Self {
37        ChatBuilder {
38            _output: std::marker::PhantomData,
39            ..Default::default()
40        }
41    }
42
43    pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
44    where
45        T: JsonSchema + DeserializeOwned,
46    {
47        let shape = schemars::schema_for!(T);
48
49        ChatBuilder {
50            model: self.model,
51            max_steps: self.max_steps,
52            max_retries: self.max_retries,
53            retry_strategy: self.retry_strategy,
54            before_strategy: self.before_strategy,
55            after_strategy: self.after_strategy,
56            output_shape: Some(shape),
57            tools: self.tools,
58            model_options: self.model_options,
59            _output: std::marker::PhantomData,
60        }
61    }
62
63    #[cfg(feature = "stream")]
64    pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
65    where
66        CP: StreamProvider,
67    {
68        if self.output_shape.is_some() {
69            println!(
70                "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
71            );
72        }
73
74        ChatBuilder {
75            model: self.model,
76            max_steps: self.max_steps,
77            max_retries: self.max_retries,
78            retry_strategy: self.retry_strategy,
79            before_strategy: self.before_strategy,
80            after_strategy: self.after_strategy,
81            output_shape: None, // No shape for pure streaming
82            tools: self.tools,
83            model_options: self.model_options,
84            _output: std::marker::PhantomData,
85        }
86    }
87
88    pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
89        if self.output_shape.is_some() {
90            println!(
91                "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
92            );
93        }
94
95        ChatBuilder {
96            model: self.model,
97            max_retries: self.max_retries,
98            retry_strategy: self.retry_strategy,
99            before_strategy: self.before_strategy,
100            after_strategy: self.after_strategy,
101            output_shape: None,
102            tools: None,
103            max_steps: None,
104            model_options: self.model_options,
105            _output: std::marker::PhantomData,
106        }
107    }
108}
109
110impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
111    pub fn with_max_steps(mut self, max_steps: u16) -> Self {
112        self.max_steps = Some(max_steps);
113        self
114    }
115
116    pub fn with_max_retries(mut self, max_retries: u16) -> Self {
117        self.max_retries = Some(max_retries);
118        self
119    }
120
121    pub fn with_tools(mut self, tools: ToolCollection) -> Self {
122        self.tools = Some(tools);
123        self
124    }
125
126    pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
127        self.retry_strategy = Some(retry_strategy);
128        self
129    }
130
131    pub fn with_model(mut self, model: CP) -> Self {
132        self.model = Some(model);
133        self
134    }
135
136    pub fn with_options(mut self, options: ChatOptions) -> Self {
137        self.model_options = Some(options);
138        self
139    }
140
141    pub fn build(self) -> Chat<CP, Output> {
142        Chat {
143            model: self.model.expect("Need to set a model"),
144            output_shape: self.output_shape,
145            max_steps: self.max_steps,
146            max_retries: self.max_retries,
147            retry_strategy: self.retry_strategy,
148            before_strategy: self.before_strategy,
149            after_strategy: self.after_strategy,
150            tools: self.tools,
151            model_options: self.model_options,
152            _output: std::marker::PhantomData,
153        }
154    }
155}
156
157impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
158    fn default() -> Self {
159        ChatBuilder {
160            model: None,
161            output_shape: None,
162            model_options: None,
163            max_steps: None,
164            max_retries: None,
165            retry_strategy: None,
166            before_strategy: None,
167            after_strategy: None,
168            tools: None,
169            _output: std::marker::PhantomData,
170        }
171    }
172}