harmony_protocol/
chat.rs

1use core::fmt;
2use serde::{
3    de::{self, Visitor},
4    Deserialize, Deserializer, Serialize,
5};
6use std::collections::BTreeMap;
7use std::{fmt::Display, marker::PhantomData};
8
9#[serde_with::skip_serializing_none]
10#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
11pub struct Author {
12    pub role: Role,
13    pub name: Option<String>,
14}
15
16impl Author {
17    pub fn new(role: Role, name: impl Into<String>) -> Self {
18        Self {
19            role,
20            name: Some(name.into()),
21        }
22    }
23}
24
25impl From<Role> for Author {
26    fn from(role: Role) -> Self {
27        Self { role, name: None }
28    }
29}
30
31#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
32#[serde(rename_all = "snake_case")]
33pub enum Role {
34    User,
35    Assistant,
36    System,
37    Developer,
38    Tool,
39}
40
41impl TryFrom<&str> for Role {
42    type Error = &'static str;
43    fn try_from(value: &str) -> Result<Self, Self::Error> {
44        match value {
45            "user" => Ok(Role::User),
46            "assistant" => Ok(Role::Assistant),
47            "system" => Ok(Role::System),
48            "developer" => Ok(Role::Developer),
49            "tool" => Ok(Role::Tool),
50            _ => Err("Unknown role"),
51        }
52    }
53}
54
55impl Role {
56    pub fn as_str(&self) -> &str {
57        match self {
58            Role::User => "user",
59            Role::Assistant => "assistant",
60            Role::System => "system",
61            Role::Developer => "developer",
62            Role::Tool => "tool",
63        }
64    }
65}
66
67impl Display for Role {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        write!(f, "{}", self.as_str())
70    }
71}
72
73#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
74#[serde(rename_all = "snake_case", tag = "type")]
75pub enum Content {
76    Text(TextContent),
77    SystemContent(SystemContent),
78    DeveloperContent(DeveloperContent),
79}
80
81impl<T> From<T> for Content
82where
83    T: Into<String>,
84{
85    fn from(text: T) -> Self {
86        Self::Text(TextContent { text: text.into() })
87    }
88}
89
90impl From<SystemContent> for Content {
91    fn from(sys: SystemContent) -> Self {
92        Self::SystemContent(sys)
93    }
94}
95
96impl From<DeveloperContent> for Content {
97    fn from(dev: DeveloperContent) -> Self {
98        Self::DeveloperContent(dev)
99    }
100}
101
102#[serde_with::skip_serializing_none]
103#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
104pub struct Message {
105    #[serde(flatten)]
106    pub author: Author,
107    pub recipient: Option<String>,
108    #[serde(
109        deserialize_with = "de_string_or_content_vec",
110        serialize_with = "se_string_or_content_vec"
111    )]
112    pub content: Vec<Content>,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub channel: Option<String>,
115    pub content_type: Option<String>,
116}
117
118impl Message {
119    pub fn from_author_and_content<C>(author: Author, content: C) -> Self
120    where
121        C: Into<Content>,
122    {
123        Message {
124            author,
125            content: vec![content.into()],
126            channel: None,
127            recipient: None,
128            content_type: None,
129        }
130    }
131
132    pub fn from_role_and_content<C>(role: Role, content: C) -> Self
133    where
134        C: Into<Content>,
135    {
136        Self::from_author_and_content(Author { role, name: None }, content)
137    }
138
139    pub fn from_role_and_contents<I>(role: Role, content: I) -> Self
140    where
141        I: IntoIterator<Item = Content>,
142    {
143        Message {
144            author: Author { role, name: None },
145            content: content.into_iter().collect(),
146            channel: None,
147            recipient: None,
148            content_type: None,
149        }
150    }
151
152    pub fn adding_content<C>(mut self, content: C) -> Self
153    where
154        C: Into<Content>,
155    {
156        self.content.push(content.into());
157        self
158    }
159
160    pub fn with_channel<S>(mut self, channel: S) -> Self
161    where
162        S: Into<String>,
163    {
164        self.channel = Some(channel.into());
165        self
166    }
167
168    pub fn with_recipient<S>(mut self, recipient: S) -> Self
169    where
170        S: Into<String>,
171    {
172        self.recipient = Some(recipient.into());
173        self
174    }
175
176    pub fn with_content_type<S>(mut self, content_type: S) -> Self
177    where
178        S: Into<String>,
179    {
180        self.content_type = Some(content_type.into());
181        self
182    }
183}
184
185#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
186pub struct TextContent {
187    pub text: String,
188}
189
190#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)]
191pub enum ReasoningEffort {
192    Low,
193    Medium,
194    High,
195}
196
197#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Default)]
198pub struct ChannelConfig {
199    pub valid_channels: Vec<String>,
200    pub channel_required: bool,
201}
202
203impl ChannelConfig {
204    pub fn require_channels<I, T>(channels: I) -> Self
205    where
206        I: IntoIterator<Item = T>,
207        T: Into<String>,
208    {
209        Self {
210            valid_channels: channels.into_iter().map(|c| c.into()).collect(),
211            channel_required: true,
212        }
213    }
214}
215
216#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
217pub struct ToolNamespaceConfig {
218    pub name: String,
219    pub description: Option<String>,
220    pub tools: Vec<ToolDescription>,
221}
222
223impl ToolNamespaceConfig {
224    pub fn new(
225        name: impl Into<String>,
226        description: Option<String>,
227        tools: Vec<ToolDescription>,
228    ) -> Self {
229        Self {
230            name: name.into(),
231            description,
232            tools,
233        }
234    }
235
236    pub fn browser() -> Self {
237        ToolNamespaceConfig::new(
238            "browser",
239            Some("Tool for browsing.\nThe `cursor` appears in brackets before each browsing display: `[{cursor}]`.\nCite information from the tool using the following format:\n`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\nDo not quote more than 10 words directly from the tool output.\nsources=web (default: web)".to_string()),
240            vec![
241                ToolDescription::new(
242                    "search",
243                    "Searches for information related to `query` and displays `topn` results.",
244                    Some(serde_json::json!({
245                        "type": "object",
246                        "properties": {
247                            "query": {"type": "string"},
248                            "topn": {"type": "number", "default": 10},
249                            "source": {"type": "string"}
250                        },
251                        "required": ["query"]
252                    })),
253                ),
254                ToolDescription::new(
255                    "open",
256                    "Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\nValid link ids are displayed with the formatting: `【{id}†.*】`.\nIf `cursor` is not provided, the most recent page is implied.\nIf `id` is a string, it is treated as a fully qualified URL associated with `source`.\nIf `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\nUse this function without `id` to scroll to a new location of an opened page.",
257                    Some(serde_json::json!({
258                        "type": "object",
259                        "properties": {
260                            "id": {
261                                "type": ["number", "string"],
262                                "default": -1
263                            },
264                            "cursor": {"type": "number", "default": -1},
265                            "loc": {"type": "number", "default": -1},
266                            "num_lines": {"type": "number", "default": -1},
267                            "view_source": {"type": "boolean", "default": false},
268                            "source": {"type": "string"}
269                        }
270                    })),
271                ),
272                ToolDescription::new(
273                    "find",
274                    "Finds exact matches of `pattern` in the current page, or the page given by `cursor`.",
275                    Some(serde_json::json!({
276                        "type": "object",
277                        "properties": {
278                            "pattern": {"type": "string"},
279                            "cursor": {"type": "number", "default": -1}
280                        },
281                        "required": ["pattern"]
282                    })),
283                ),
284            ],
285        )
286    }
287
288    pub fn python() -> Self {
289        ToolNamespaceConfig::new(
290            "python",
291            Some("Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.".to_string()),
292            vec![],
293        )
294    }
295}
296
297#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
298pub struct SystemContent {
299    pub model_identity: Option<String>,
300    pub reasoning_effort: Option<ReasoningEffort>,
301    pub tools: Option<BTreeMap<String, ToolNamespaceConfig>>,
302    pub conversation_start_date: Option<String>,
303    pub knowledge_cutoff: Option<String>,
304    pub channel_config: Option<ChannelConfig>,
305}
306
307impl Default for SystemContent {
308    fn default() -> Self {
309        Self {
310            model_identity: Some(
311                "You are ChatGPT, a large language model trained by OpenAI.".to_string(),
312            ),
313            reasoning_effort: Some(ReasoningEffort::Medium),
314            tools: None,
315            conversation_start_date: None,
316            knowledge_cutoff: Some("2024-06".to_string()),
317            channel_config: Some(ChannelConfig::require_channels([
318                "analysis",
319                "commentary",
320                "final",
321            ])),
322        }
323    }
324}
325
326impl SystemContent {
327    pub fn new() -> Self {
328        Default::default()
329    }
330
331    pub fn with_model_identity(mut self, model_identity: impl Into<String>) -> Self {
332        self.model_identity = Some(model_identity.into());
333        self
334    }
335
336    pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
337        self.reasoning_effort = Some(effort);
338        self
339    }
340
341    pub fn with_tools(mut self, ns_config: ToolNamespaceConfig) -> Self {
342        let ns = ns_config.name.clone();
343        if let Some(ref mut map) = self.tools {
344            map.insert(ns, ns_config);
345        } else {
346            let mut map = BTreeMap::new();
347            map.insert(ns, ns_config);
348            self.tools = Some(map);
349        }
350        self
351    }
352
353    pub fn with_conversation_start_date(
354        mut self,
355        conversation_start_date: impl Into<String>,
356    ) -> Self {
357        self.conversation_start_date = Some(conversation_start_date.into());
358        self
359    }
360
361    pub fn with_knowledge_cutoff(mut self, knowledge_cutoff: impl Into<String>) -> Self {
362        self.knowledge_cutoff = Some(knowledge_cutoff.into());
363        self
364    }
365
366    pub fn with_channel_config(mut self, channel_config: ChannelConfig) -> Self {
367        self.channel_config = Some(channel_config);
368        self
369    }
370
371    pub fn with_required_channels<I, T>(mut self, channels: I) -> Self
372    where
373        I: IntoIterator<Item = T>,
374        T: Into<String>,
375    {
376        self.channel_config = Some(ChannelConfig::require_channels(channels));
377        self
378    }
379
380    pub fn with_browser_tool(mut self) -> Self {
381        self = self.with_tools(ToolNamespaceConfig::browser());
382        self
383    }
384
385    pub fn with_python_tool(mut self) -> Self {
386        self = self.with_tools(ToolNamespaceConfig::python());
387        self
388    }
389}
390
391#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
392pub struct ToolDescription {
393    pub name: String,
394    pub description: String,
395    pub parameters: Option<serde_json::Value>,
396}
397
398impl ToolDescription {
399    pub fn new(
400        name: impl Into<String>,
401        description: impl Into<String>,
402        parameters: Option<serde_json::Value>,
403    ) -> Self {
404        Self {
405            name: name.into(),
406            description: description.into(),
407            parameters,
408        }
409    }
410}
411
412#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
413pub struct Conversation {
414    pub messages: Vec<Message>,
415}
416
417impl Conversation {
418    pub fn from_messages<I>(messages: I) -> Self
419    where
420        I: IntoIterator<Item = Message>,
421    {
422        Self {
423            messages: messages.into_iter().collect(),
424        }
425    }
426}
427
428impl<'a> IntoIterator for &'a Conversation {
429    type Item = &'a Message;
430    type IntoIter = std::slice::Iter<'a, Message>;
431
432    fn into_iter(self) -> Self::IntoIter {
433        self.messages.iter()
434    }
435}
436
437fn de_string_or_content_vec<'de, D>(deserializer: D) -> Result<Vec<Content>, D::Error>
438where
439    D: Deserializer<'de>,
440{
441    struct StringOrContentVec(PhantomData<fn() -> Vec<Content>>);
442
443    impl<'de> Visitor<'de> for StringOrContentVec {
444        type Value = Vec<Content>;
445
446        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
447            formatter.write_str("string or list of content")
448        }
449
450        fn visit_str<E>(self, value: &str) -> Result<Vec<Content>, E>
451        where
452            E: de::Error,
453        {
454            Ok(vec![Content::Text(TextContent {
455                text: value.to_owned(),
456            })])
457        }
458
459        fn visit_seq<A>(self, seq: A) -> std::result::Result<Self::Value, A::Error>
460        where
461            A: de::SeqAccess<'de>,
462        {
463            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))
464        }
465    }
466
467    deserializer.deserialize_any(StringOrContentVec(PhantomData))
468}
469
470fn se_string_or_content_vec<S>(value: &Vec<Content>, serializer: S) -> Result<S::Ok, S::Error>
471where
472    S: serde::Serializer,
473{
474    if value.len() == 1 {
475        if let Content::Text(TextContent { text }) = &value[0] {
476            return serializer.serialize_str(text);
477        }
478    }
479    value.serialize(serializer)
480}
481
482#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Default)]
483pub struct DeveloperContent {
484    pub instructions: Option<String>,
485    pub tools: Option<BTreeMap<String, ToolNamespaceConfig>>,
486}
487
488impl DeveloperContent {
489    pub fn new() -> Self {
490        Self::default()
491    }
492
493    pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
494        self.instructions = Some(instructions.into());
495        self
496    }
497
498    pub fn with_tools(mut self, ns_config: ToolNamespaceConfig) -> Self {
499        let ns = ns_config.name.clone();
500        if let Some(ref mut map) = self.tools {
501            map.insert(ns, ns_config);
502        } else {
503            let mut map = BTreeMap::new();
504            map.insert(ns, ns_config);
505            self.tools = Some(map);
506        }
507        self
508    }
509
510    pub fn with_function_tools(mut self, tools: Vec<ToolDescription>) -> Self {
511        self = self.with_tools(ToolNamespaceConfig::new("functions", None, tools));
512        self
513    }
514}