1use core::fmt;
2use serde::{
3 de::{self, Visitor},
4 Deserialize, Deserializer, Serialize,
5};
6use std::collections::BTreeMap;
7use std::{fmt::Display, marker::PhantomData};
8
9#[serde_with::skip_serializing_none]
10#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
11pub struct Author {
12 pub role: Role,
13 pub name: Option<String>,
14}
15
16impl Author {
17 pub fn new(role: Role, name: impl Into<String>) -> Self {
18 Self {
19 role,
20 name: Some(name.into()),
21 }
22 }
23}
24
25impl From<Role> for Author {
26 fn from(role: Role) -> Self {
27 Self { role, name: None }
28 }
29}
30
31#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
32#[serde(rename_all = "snake_case")]
33pub enum Role {
34 User,
35 Assistant,
36 System,
37 Developer,
38 Tool,
39}
40
41impl TryFrom<&str> for Role {
42 type Error = &'static str;
43 fn try_from(value: &str) -> Result<Self, Self::Error> {
44 match value {
45 "user" => Ok(Role::User),
46 "assistant" => Ok(Role::Assistant),
47 "system" => Ok(Role::System),
48 "developer" => Ok(Role::Developer),
49 "tool" => Ok(Role::Tool),
50 _ => Err("Unknown role"),
51 }
52 }
53}
54
55impl Role {
56 pub fn as_str(&self) -> &str {
57 match self {
58 Role::User => "user",
59 Role::Assistant => "assistant",
60 Role::System => "system",
61 Role::Developer => "developer",
62 Role::Tool => "tool",
63 }
64 }
65}
66
67impl Display for Role {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(f, "{}", self.as_str())
70 }
71}
72
73#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
74#[serde(rename_all = "snake_case", tag = "type")]
75pub enum Content {
76 Text(TextContent),
77 SystemContent(SystemContent),
78 DeveloperContent(DeveloperContent),
79}
80
81impl<T> From<T> for Content
82where
83 T: Into<String>,
84{
85 fn from(text: T) -> Self {
86 Self::Text(TextContent { text: text.into() })
87 }
88}
89
90impl From<SystemContent> for Content {
91 fn from(sys: SystemContent) -> Self {
92 Self::SystemContent(sys)
93 }
94}
95
96impl From<DeveloperContent> for Content {
97 fn from(dev: DeveloperContent) -> Self {
98 Self::DeveloperContent(dev)
99 }
100}
101
102#[serde_with::skip_serializing_none]
103#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
104pub struct Message {
105 #[serde(flatten)]
106 pub author: Author,
107 pub recipient: Option<String>,
108 #[serde(
109 deserialize_with = "de_string_or_content_vec",
110 serialize_with = "se_string_or_content_vec"
111 )]
112 pub content: Vec<Content>,
113 #[serde(skip_serializing_if = "Option::is_none")]
114 pub channel: Option<String>,
115 pub content_type: Option<String>,
116}
117
118impl Message {
119 pub fn from_author_and_content<C>(author: Author, content: C) -> Self
120 where
121 C: Into<Content>,
122 {
123 Message {
124 author,
125 content: vec![content.into()],
126 channel: None,
127 recipient: None,
128 content_type: None,
129 }
130 }
131
132 pub fn from_role_and_content<C>(role: Role, content: C) -> Self
133 where
134 C: Into<Content>,
135 {
136 Self::from_author_and_content(Author { role, name: None }, content)
137 }
138
139 pub fn from_role_and_contents<I>(role: Role, content: I) -> Self
140 where
141 I: IntoIterator<Item = Content>,
142 {
143 Message {
144 author: Author { role, name: None },
145 content: content.into_iter().collect(),
146 channel: None,
147 recipient: None,
148 content_type: None,
149 }
150 }
151
152 pub fn adding_content<C>(mut self, content: C) -> Self
153 where
154 C: Into<Content>,
155 {
156 self.content.push(content.into());
157 self
158 }
159
160 pub fn with_channel<S>(mut self, channel: S) -> Self
161 where
162 S: Into<String>,
163 {
164 self.channel = Some(channel.into());
165 self
166 }
167
168 pub fn with_recipient<S>(mut self, recipient: S) -> Self
169 where
170 S: Into<String>,
171 {
172 self.recipient = Some(recipient.into());
173 self
174 }
175
176 pub fn with_content_type<S>(mut self, content_type: S) -> Self
177 where
178 S: Into<String>,
179 {
180 self.content_type = Some(content_type.into());
181 self
182 }
183}
184
185#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
186pub struct TextContent {
187 pub text: String,
188}
189
190#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)]
191pub enum ReasoningEffort {
192 Low,
193 Medium,
194 High,
195}
196
197#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Default)]
198pub struct ChannelConfig {
199 pub valid_channels: Vec<String>,
200 pub channel_required: bool,
201}
202
203impl ChannelConfig {
204 pub fn require_channels<I, T>(channels: I) -> Self
205 where
206 I: IntoIterator<Item = T>,
207 T: Into<String>,
208 {
209 Self {
210 valid_channels: channels.into_iter().map(|c| c.into()).collect(),
211 channel_required: true,
212 }
213 }
214}
215
216#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
217pub struct ToolNamespaceConfig {
218 pub name: String,
219 pub description: Option<String>,
220 pub tools: Vec<ToolDescription>,
221}
222
223impl ToolNamespaceConfig {
224 pub fn new(
225 name: impl Into<String>,
226 description: Option<String>,
227 tools: Vec<ToolDescription>,
228 ) -> Self {
229 Self {
230 name: name.into(),
231 description,
232 tools,
233 }
234 }
235
236 pub fn browser() -> Self {
237 ToolNamespaceConfig::new(
238 "browser",
239 Some("Tool for browsing.\nThe `cursor` appears in brackets before each browsing display: `[{cursor}]`.\nCite information from the tool using the following format:\n`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\nDo not quote more than 10 words directly from the tool output.\nsources=web (default: web)".to_string()),
240 vec![
241 ToolDescription::new(
242 "search",
243 "Searches for information related to `query` and displays `topn` results.",
244 Some(serde_json::json!({
245 "type": "object",
246 "properties": {
247 "query": {"type": "string"},
248 "topn": {"type": "number", "default": 10},
249 "source": {"type": "string"}
250 },
251 "required": ["query"]
252 })),
253 ),
254 ToolDescription::new(
255 "open",
256 "Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\nValid link ids are displayed with the formatting: `【{id}†.*】`.\nIf `cursor` is not provided, the most recent page is implied.\nIf `id` is a string, it is treated as a fully qualified URL associated with `source`.\nIf `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\nUse this function without `id` to scroll to a new location of an opened page.",
257 Some(serde_json::json!({
258 "type": "object",
259 "properties": {
260 "id": {
261 "type": ["number", "string"],
262 "default": -1
263 },
264 "cursor": {"type": "number", "default": -1},
265 "loc": {"type": "number", "default": -1},
266 "num_lines": {"type": "number", "default": -1},
267 "view_source": {"type": "boolean", "default": false},
268 "source": {"type": "string"}
269 }
270 })),
271 ),
272 ToolDescription::new(
273 "find",
274 "Finds exact matches of `pattern` in the current page, or the page given by `cursor`.",
275 Some(serde_json::json!({
276 "type": "object",
277 "properties": {
278 "pattern": {"type": "string"},
279 "cursor": {"type": "number", "default": -1}
280 },
281 "required": ["pattern"]
282 })),
283 ),
284 ],
285 )
286 }
287
288 pub fn python() -> Self {
289 ToolNamespaceConfig::new(
290 "python",
291 Some("Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.".to_string()),
292 vec![],
293 )
294 }
295}
296
297#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
298pub struct SystemContent {
299 pub model_identity: Option<String>,
300 pub reasoning_effort: Option<ReasoningEffort>,
301 pub tools: Option<BTreeMap<String, ToolNamespaceConfig>>,
302 pub conversation_start_date: Option<String>,
303 pub knowledge_cutoff: Option<String>,
304 pub channel_config: Option<ChannelConfig>,
305}
306
307impl Default for SystemContent {
308 fn default() -> Self {
309 Self {
310 model_identity: Some(
311 "You are ChatGPT, a large language model trained by OpenAI.".to_string(),
312 ),
313 reasoning_effort: Some(ReasoningEffort::Medium),
314 tools: None,
315 conversation_start_date: None,
316 knowledge_cutoff: Some("2024-06".to_string()),
317 channel_config: Some(ChannelConfig::require_channels([
318 "analysis",
319 "commentary",
320 "final",
321 ])),
322 }
323 }
324}
325
326impl SystemContent {
327 pub fn new() -> Self {
328 Default::default()
329 }
330
331 pub fn with_model_identity(mut self, model_identity: impl Into<String>) -> Self {
332 self.model_identity = Some(model_identity.into());
333 self
334 }
335
336 pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
337 self.reasoning_effort = Some(effort);
338 self
339 }
340
341 pub fn with_tools(mut self, ns_config: ToolNamespaceConfig) -> Self {
342 let ns = ns_config.name.clone();
343 if let Some(ref mut map) = self.tools {
344 map.insert(ns, ns_config);
345 } else {
346 let mut map = BTreeMap::new();
347 map.insert(ns, ns_config);
348 self.tools = Some(map);
349 }
350 self
351 }
352
353 pub fn with_conversation_start_date(
354 mut self,
355 conversation_start_date: impl Into<String>,
356 ) -> Self {
357 self.conversation_start_date = Some(conversation_start_date.into());
358 self
359 }
360
361 pub fn with_knowledge_cutoff(mut self, knowledge_cutoff: impl Into<String>) -> Self {
362 self.knowledge_cutoff = Some(knowledge_cutoff.into());
363 self
364 }
365
366 pub fn with_channel_config(mut self, channel_config: ChannelConfig) -> Self {
367 self.channel_config = Some(channel_config);
368 self
369 }
370
371 pub fn with_required_channels<I, T>(mut self, channels: I) -> Self
372 where
373 I: IntoIterator<Item = T>,
374 T: Into<String>,
375 {
376 self.channel_config = Some(ChannelConfig::require_channels(channels));
377 self
378 }
379
380 pub fn with_browser_tool(mut self) -> Self {
381 self = self.with_tools(ToolNamespaceConfig::browser());
382 self
383 }
384
385 pub fn with_python_tool(mut self) -> Self {
386 self = self.with_tools(ToolNamespaceConfig::python());
387 self
388 }
389}
390
391#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
392pub struct ToolDescription {
393 pub name: String,
394 pub description: String,
395 pub parameters: Option<serde_json::Value>,
396}
397
398impl ToolDescription {
399 pub fn new(
400 name: impl Into<String>,
401 description: impl Into<String>,
402 parameters: Option<serde_json::Value>,
403 ) -> Self {
404 Self {
405 name: name.into(),
406 description: description.into(),
407 parameters,
408 }
409 }
410}
411
412#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
413pub struct Conversation {
414 pub messages: Vec<Message>,
415}
416
417impl Conversation {
418 pub fn from_messages<I>(messages: I) -> Self
419 where
420 I: IntoIterator<Item = Message>,
421 {
422 Self {
423 messages: messages.into_iter().collect(),
424 }
425 }
426}
427
428impl<'a> IntoIterator for &'a Conversation {
429 type Item = &'a Message;
430 type IntoIter = std::slice::Iter<'a, Message>;
431
432 fn into_iter(self) -> Self::IntoIter {
433 self.messages.iter()
434 }
435}
436
437fn de_string_or_content_vec<'de, D>(deserializer: D) -> Result<Vec<Content>, D::Error>
438where
439 D: Deserializer<'de>,
440{
441 struct StringOrContentVec(PhantomData<fn() -> Vec<Content>>);
442
443 impl<'de> Visitor<'de> for StringOrContentVec {
444 type Value = Vec<Content>;
445
446 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
447 formatter.write_str("string or list of content")
448 }
449
450 fn visit_str<E>(self, value: &str) -> Result<Vec<Content>, E>
451 where
452 E: de::Error,
453 {
454 Ok(vec![Content::Text(TextContent {
455 text: value.to_owned(),
456 })])
457 }
458
459 fn visit_seq<A>(self, seq: A) -> std::result::Result<Self::Value, A::Error>
460 where
461 A: de::SeqAccess<'de>,
462 {
463 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
464 }
465 }
466
467 deserializer.deserialize_any(StringOrContentVec(PhantomData))
468}
469
470fn se_string_or_content_vec<S>(value: &Vec<Content>, serializer: S) -> Result<S::Ok, S::Error>
471where
472 S: serde::Serializer,
473{
474 if value.len() == 1 {
475 if let Content::Text(TextContent { text }) = &value[0] {
476 return serializer.serialize_str(text);
477 }
478 }
479 value.serialize(serializer)
480}
481
482#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Default)]
483pub struct DeveloperContent {
484 pub instructions: Option<String>,
485 pub tools: Option<BTreeMap<String, ToolNamespaceConfig>>,
486}
487
488impl DeveloperContent {
489 pub fn new() -> Self {
490 Self::default()
491 }
492
493 pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
494 self.instructions = Some(instructions.into());
495 self
496 }
497
498 pub fn with_tools(mut self, ns_config: ToolNamespaceConfig) -> Self {
499 let ns = ns_config.name.clone();
500 if let Some(ref mut map) = self.tools {
501 map.insert(ns, ns_config);
502 } else {
503 let mut map = BTreeMap::new();
504 map.insert(ns, ns_config);
505 self.tools = Some(map);
506 }
507 self
508 }
509
510 pub fn with_function_tools(mut self, tools: Vec<ToolDescription>) -> Self {
511 self = self.with_tools(ToolNamespaceConfig::new("functions", None, tools));
512 self
513 }
514}