1use schemars::JsonSchema;
2use serde::de::DeserializeOwned;
3use tools_rs::ToolCollection;
4
5#[cfg(feature = "stream")]
6use crate::chat::state::Streamed;
7use crate::{
8 chat::{
9 Chat,
10 state::{Embedded, Structured, Unstructured},
11 },
12 traits::CompletionProvider,
13 types::{
14 callback::{CallbackStrategy, RetryStrategy},
15 options::ChatOptions,
16 },
17};
18
19#[cfg(feature = "stream")]
20use crate::traits::StreamProvider;
21
22pub struct ChatBuilder<CP: CompletionProvider, Output = Unstructured> {
23 model: Option<CP>,
24 output_shape: Option<schemars::Schema>,
25 model_options: Option<ChatOptions>,
26 max_steps: Option<u16>,
27 max_retries: Option<u16>,
28 retry_strategy: Option<RetryStrategy>,
29 before_strategy: Option<CallbackStrategy>,
30 after_strategy: Option<CallbackStrategy>,
31 tools: Option<ToolCollection>,
32 _output: std::marker::PhantomData<Output>,
33}
34
35impl<CP: CompletionProvider> ChatBuilder<CP, Unstructured> {
36 pub fn new() -> Self {
37 ChatBuilder {
38 _output: std::marker::PhantomData,
39 ..Default::default()
40 }
41 }
42
43 pub fn with_structured_output<T>(self) -> ChatBuilder<CP, Structured<T>>
44 where
45 T: JsonSchema + DeserializeOwned,
46 {
47 let shape = schemars::schema_for!(T);
48
49 ChatBuilder {
50 model: self.model,
51 max_steps: self.max_steps,
52 max_retries: self.max_retries,
53 retry_strategy: self.retry_strategy,
54 before_strategy: self.before_strategy,
55 after_strategy: self.after_strategy,
56 output_shape: Some(shape),
57 tools: self.tools,
58 model_options: self.model_options,
59 _output: std::marker::PhantomData,
60 }
61 }
62
63 #[cfg(feature = "stream")]
64 pub fn with_streamed_response(self) -> ChatBuilder<CP, Streamed>
65 where
66 CP: StreamProvider,
67 {
68 if self.output_shape.is_some() {
69 println!(
70 "Warning: Cannot call streamed responses with structured outputs. Output shape will be set to None"
71 );
72 }
73
74 ChatBuilder {
75 model: self.model,
76 max_steps: self.max_steps,
77 max_retries: self.max_retries,
78 retry_strategy: self.retry_strategy,
79 before_strategy: self.before_strategy,
80 after_strategy: self.after_strategy,
81 output_shape: None, tools: self.tools,
83 model_options: self.model_options,
84 _output: std::marker::PhantomData,
85 }
86 }
87
88 pub fn with_embeddings(self) -> ChatBuilder<CP, Embedded> {
89 if self.output_shape.is_some() {
90 println!(
91 "Warning: Cannot call embedding responses with structured outputs. Output shape will be set to None"
92 );
93 }
94
95 ChatBuilder {
96 model: self.model,
97 max_retries: self.max_retries,
98 retry_strategy: self.retry_strategy,
99 before_strategy: self.before_strategy,
100 after_strategy: self.after_strategy,
101 output_shape: None,
102 tools: None,
103 max_steps: None,
104 model_options: self.model_options,
105 _output: std::marker::PhantomData,
106 }
107 }
108}
109
110impl<CP: CompletionProvider, Output> ChatBuilder<CP, Output> {
111 pub fn with_max_steps(mut self, max_steps: u16) -> Self {
112 self.max_steps = Some(max_steps);
113 self
114 }
115
116 pub fn with_max_retries(mut self, max_retries: u16) -> Self {
117 self.max_retries = Some(max_retries);
118 self
119 }
120
121 pub fn with_tools(mut self, tools: ToolCollection) -> Self {
122 self.tools = Some(tools);
123 self
124 }
125
126 pub fn with_retry_strategy(mut self, retry_strategy: RetryStrategy) -> Self {
127 self.retry_strategy = Some(retry_strategy);
128 self
129 }
130
131 pub fn with_model(mut self, model: CP) -> Self {
132 self.model = Some(model);
133 self
134 }
135
136 pub fn with_options(mut self, options: ChatOptions) -> Self {
137 self.model_options = Some(options);
138 self
139 }
140
141 pub fn build(self) -> Chat<CP, Output> {
142 Chat {
143 model: self.model.expect("Need to set a model"),
144 output_shape: self.output_shape,
145 max_steps: self.max_steps,
146 max_retries: self.max_retries,
147 retry_strategy: self.retry_strategy,
148 before_strategy: self.before_strategy,
149 after_strategy: self.after_strategy,
150 tools: self.tools,
151 model_options: self.model_options,
152 _output: std::marker::PhantomData,
153 }
154 }
155}
156
157impl<CP: CompletionProvider> Default for ChatBuilder<CP, Unstructured> {
158 fn default() -> Self {
159 ChatBuilder {
160 model: None,
161 output_shape: None,
162 model_options: None,
163 max_steps: None,
164 max_retries: None,
165 retry_strategy: None,
166 before_strategy: None,
167 after_strategy: None,
168 tools: None,
169 _output: std::marker::PhantomData,
170 }
171 }
172}