artificial_openai/api_v1/chat_completion.rs

1use artificial_core::generic::{GenericMessage, GenericRole};
2use serde::de::{self, Visitor};
3use serde::{Deserialize, Deserializer, Serialize};
4
5use std::fmt;
6
7use crate::impl_builder_methods;
8
9use super::common;
10
11#[derive(Debug, Serialize, Deserialize, Clone)]
12pub struct ChatCompletionRequest {
13    pub model: String,
14    pub messages: Vec<ChatCompletionMessage>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub temperature: Option<f64>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub top_p: Option<f64>,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub n: Option<i64>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub response_format: Option<serde_json::Value>,
23}
24
25impl ChatCompletionRequest {
26    pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
27        Self {
28            model,
29            messages,
30            temperature: None,
31            top_p: None,
32            n: None,
33            response_format: None,
34        }
35    }
36}
37
38impl_builder_methods!(
39    ChatCompletionRequest,
40    temperature: f64,
41    top_p: f64,
42    n: i64,
43    response_format: serde_json::Value
44);
45
46#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
47#[serde(rename_all = "snake_case")]
48pub enum MessageRole {
49    User,
50    System,
51    Assistant,
52    Function,
53    Tool,
54}
55
/// Payload of a [`ChatCompletionMessage`].
///
/// Only plain text is supported. Serialization is hand-written: empty text is
/// written as JSON `null`, and `null` is read back as empty text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Content {
    /// Plain text body.
    Text(String),
}
60
61impl serde::Serialize for Content {
62    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
63    where
64        S: serde::Serializer,
65    {
66        match *self {
67            Content::Text(ref text) => {
68                if text.is_empty() {
69                    serializer.serialize_none()
70                } else {
71                    serializer.serialize_str(text)
72                }
73            }
74        }
75    }
76}
77
78impl<'de> Deserialize<'de> for Content {
79    fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
80    where
81        D: Deserializer<'de>,
82    {
83        struct ContentVisitor;
84
85        impl<'de> Visitor<'de> for ContentVisitor {
86            type Value = Content;
87
88            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
89                formatter.write_str("a valid content type")
90            }
91
92            fn visit_str<E>(self, value: &str) -> Result<Content, E>
93            where
94                E: de::Error,
95            {
96                Ok(Content::Text(value.to_string()))
97            }
98
99            fn visit_none<E>(self) -> Result<Self::Value, E>
100            where
101                E: de::Error,
102            {
103                Ok(Content::Text(String::new()))
104            }
105
106            fn visit_unit<E>(self) -> Result<Self::Value, E>
107            where
108                E: de::Error,
109            {
110                Ok(Content::Text(String::new()))
111            }
112        }
113
114        deserializer.deserialize_any(ContentVisitor)
115    }
116}
117
118#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
119#[serde(rename_all = "snake_case")]
120pub enum ContentType {
121    Text,
122}
123
124#[derive(Debug, Deserialize, Serialize, Clone)]
125pub struct ChatCompletionMessage {
126    pub role: MessageRole,
127    pub content: Content,
128}
129
130#[derive(Debug, Deserialize, Serialize, Clone)]
131pub struct ChatCompletionMessageForResponse {
132    pub role: MessageRole,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub content: Option<String>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub reasoning_content: Option<String>,
137}
138
139#[derive(Debug, Deserialize, Serialize)]
140pub struct ChatCompletionChoice {
141    pub index: i64,
142    pub message: ChatCompletionMessageForResponse,
143    pub finish_reason: Option<FinishReason>,
144    pub finish_details: Option<FinishDetails>,
145}
146
147#[derive(Debug, Deserialize, Serialize)]
148pub struct ChatCompletionResponse {
149    pub id: Option<String>,
150    pub object: String,
151    pub created: i64,
152    pub model: String,
153    pub choices: Vec<ChatCompletionChoice>,
154    pub usage: common::Usage,
155    pub system_fingerprint: Option<String>,
156}
157
158#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
159#[serde(rename_all = "snake_case")]
160pub enum FinishReason {
161    Stop,
162    Length,
163    ContentFilter,
164    ToolCalls,
165    Null,
166}
167
168#[derive(Debug, Deserialize, Serialize)]
169#[allow(non_camel_case_types)]
170pub struct FinishDetails {
171    pub r#type: FinishReason,
172    pub stop: String,
173}
174
175impl From<GenericRole> for MessageRole {
176    fn from(value: GenericRole) -> Self {
177        match value {
178            GenericRole::System => MessageRole::System,
179            GenericRole::Assistant => MessageRole::Assistant,
180            GenericRole::User => MessageRole::User,
181            GenericRole::Tool => MessageRole::Tool,
182        }
183    }
184}
185
186impl From<GenericMessage> for ChatCompletionMessage {
187    fn from(value: GenericMessage) -> Self {
188        Self {
189            role: value.role.into(),
190            content: Content::Text(value.message),
191        }
192    }
193}