// lm_studio_api/chat/content.rs

1use crate::prelude::*;
2
/// A single part of a chat message's content.
///
/// The hand-written `Serialize` impl for this type emits an object tagged by a
/// `"type"` field: `{"type": "text", "text": ...}` for [`Content::Text`] and
/// `{"type": "image_url", "image_url": ...}` for [`Content::Image`].
///
/// The `#[from]` attributes below map `String` and `&str` values to
/// [`Content::Text`] via the `From` derive.
#[derive(Debug, Clone, From, Eq, PartialEq)]
#[from(String, "Content::Text { text: value.into() }")]
#[from(&str, "Content::Text { text: value.into() }")]
pub enum Content {
    /// Plain text content.
    Text { text: String },
    /// Image content; serialized under the `"image_url"` key.
    Image { image: Image },
}
11
12impl Content {
13    /// Adds response chunk to content
14    pub(crate) fn add_chunk(&mut self, add_text: &str) {
15        match self {
16            Content::Text { text } => text.push_str(add_text),
17            _ => {}
18        }
19    }
20}
21
22impl ::serde::Serialize for Content {
23    fn serialize<S>(&self, se: S) -> StdResult<S::Ok, S::Error>
24    where
25        S: ::serde::Serializer,
26    {
27        use serde::ser::SerializeStruct;
28
29        match self {
30            Content::Text { text } => {
31                let mut s = se.serialize_struct("Content", 2)?;
32                s.serialize_field("type", "text")?;
33                s.serialize_field("text", text)?;
34                s.end()
35            }
36            Content::Image { image } => {
37                let mut s = se.serialize_struct("Content", 2)?;
38                s.serialize_field("type", "image_url")?;
39                s.serialize_field("image_url", image)?;
40                s.end()
41            }
42        }
43    }
44}
45
46struct ContentVisitor;
47
48impl<'de> ::serde::de::Visitor<'de> for ContentVisitor {
49    type Value = Content;
50
51    fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
52        formatter.write_str("struct Content with type field")
53    }
54
55    fn visit_map<V>(self, mut map: V) -> StdResult<Self::Value, V::Error>
56    where
57        V: ::serde::de::MapAccess<'de>,
58    {
59        let mut ctype: Option<String> = None;
60        let mut text: Option<String> = None;
61        let mut image_url: Option<Image> = None;
62
63        while let Some(key) = map.next_key::<String>()? {
64            match key.as_str() {
65                "type" => {
66                    if ctype.is_some() {
67                        return Err(serde::de::Error::duplicate_field("type"));
68                    }
69                    ctype = Some(map.next_value()?);
70                }
71                "text" => {
72                    if text.is_some() {
73                        return Err(serde::de::Error::duplicate_field("text"));
74                    }
75                    text = Some(map.next_value()?);
76                }
77                "image_url" => {
78                    if image_url.is_some() {
79                        return Err(serde::de::Error::duplicate_field("image_url"));
80                    }
81                    image_url = Some(map.next_value()?);
82                }
83                _ => {
84                    let _ : serde::de::IgnoredAny = map.next_value()?;
85                }
86            }
87        }
88
89        let ctype = ctype.ok_or_else(|| serde::de::Error::missing_field("type"))?;
90
91        match ctype.as_str() {
92            "text" => {
93                let text = text.ok_or_else(|| serde::de::Error::missing_field("text"))?;
94                Ok(Content::Text { text })
95            }
96            "image_url" => {
97                let image_url = image_url.ok_or_else(|| serde::de::Error::missing_field("image_url"))?;
98                Ok(Content::Image { image: image_url })
99            }
100            _ => Err(serde::de::Error::unknown_variant(&ctype, &["text", "image_url"])),
101        }
102    }
103}
104
105impl<'de> ::serde::Deserialize<'de> for Content {
106    fn deserialize<D>(de: D) -> ::std::result::Result<Self, D::Error>
107    where
108        D: serde::Deserializer<'de>,
109    {
110        const FIELDS: &[&str] = &["type", "text", "image_url"];
111        de.deserialize_struct("Content", FIELDS, ContentVisitor)
112    }
113}