1use super::{NoOpenAIAPIKeyError, OpenAICompatibleClient};
2use crate::{
3    ChatModel, ChatSession, CreateChatSession, CreateDefaultChatConstraintsForType,
4    GenerationParameters, ModelBuilder, ModelConstraints, StructuredChatModel,
5};
6use futures_util::StreamExt;
7use kalosm_model_types::ModelLoadingProgress;
8use kalosm_sample::Schema;
9use reqwest_eventsource::{Event, RequestBuilderExt};
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{future::Future, sync::Arc};
12use thiserror::Error;
13
/// Shared state behind an [`OpenAICompatibleChatModel`] handle: the model id
/// sent with each request plus the HTTP client/endpoint configuration.
#[derive(Debug)]
struct OpenAICompatibleChatModelInner {
    // Model identifier placed in the `model` field of each request body.
    model: String,
    // Endpoint base URL, HTTP client, and API-key resolution.
    client: OpenAICompatibleClient,
}
19
/// A chat model served over an OpenAI-compatible HTTP API.
///
/// Cloning is cheap: all clones share the same inner state through an `Arc`.
#[derive(Debug, Clone)]
pub struct OpenAICompatibleChatModel {
    inner: Arc<OpenAICompatibleChatModelInner>,
}
25
26impl OpenAICompatibleChatModel {
27    pub fn builder() -> OpenAICompatibleChatModelBuilder<false> {
29        OpenAICompatibleChatModelBuilder::new()
30    }
31}
32
/// Type-state builder for [`OpenAICompatibleChatModel`].
///
/// `WITH_NAME` tracks whether a model name has been chosen; `build` is only
/// implemented for the `true` state, so forgetting the name is a compile error.
#[derive(Debug, Default)]
pub struct OpenAICompatibleChatModelBuilder<const WITH_NAME: bool> {
    // `Some` once a model id has been selected.
    model: Option<String>,
    client: OpenAICompatibleClient,
}
39
40impl OpenAICompatibleChatModelBuilder<false> {
41    pub fn new() -> Self {
43        Self {
44            model: None,
45            client: Default::default(),
46        }
47    }
48}
49
50impl<const WITH_NAME: bool> OpenAICompatibleChatModelBuilder<WITH_NAME> {
51    pub fn with_model(self, model: impl ToString) -> OpenAICompatibleChatModelBuilder<true> {
53        OpenAICompatibleChatModelBuilder {
54            model: Some(model.to_string()),
55            client: self.client,
56        }
57    }
58
59    pub fn with_gpt_4o(self) -> OpenAICompatibleChatModelBuilder<true> {
61        self.with_model("gpt-4o")
62    }
63
64    pub fn with_chat_gpt_4o(self) -> OpenAICompatibleChatModelBuilder<true> {
66        self.with_model("chatgpt-4o-latest")
67    }
68
69    pub fn with_gpt_4o_mini(self) -> OpenAICompatibleChatModelBuilder<true> {
71        self.with_model("gpt-4o-mini")
72    }
73
74    pub fn with_client(mut self, client: OpenAICompatibleClient) -> Self {
76        self.client = client;
77        self
78    }
79}
80
81impl OpenAICompatibleChatModelBuilder<true> {
82    pub fn build(self) -> OpenAICompatibleChatModel {
84        OpenAICompatibleChatModel {
85            inner: Arc::new(OpenAICompatibleChatModelInner {
86                model: self.model.unwrap(),
87                client: self.client,
88            }),
89        }
90    }
91}
92
/// "Loading" an API-backed model performs no work up front: nothing is
/// downloaded, so construction is instant and infallible.
impl ModelBuilder for OpenAICompatibleChatModelBuilder<true> {
    type Model = OpenAICompatibleChatModel;
    // Construction cannot fail: the model name is guaranteed present by the
    // `WITH_NAME = true` type state.
    type Error = std::convert::Infallible;

    async fn start_with_loading_handler(
        self,
        _: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self::Model, Self::Error> {
        // The progress handler is ignored because there is nothing to load.
        Ok(self.build())
    }

    fn requires_download(&self) -> bool {
        false
    }
}
108
/// Errors that can occur while talking to an OpenAI-compatible chat endpoint.
#[derive(Error, Debug)]
pub enum OpenAICompatibleChatModelError {
    /// No API key could be resolved for the client.
    #[error("Error resolving API key: {0}")]
    APIKeyError(#[from] NoOpenAIAPIKeyError),
    /// The underlying HTTP request failed.
    #[error("Error making request: {0}")]
    ReqwestError(#[from] reqwest::Error),
    /// The server-sent-events stream failed mid-response.
    #[error("Error receiving server side events: {0}")]
    EventSourceError(#[from] reqwest_eventsource::Error),
    /// A response chunk contained an empty `choices` array.
    #[error("OpenAI API returned no message choices in the response")]
    NoMessageChoices,
    /// A response chunk (or the final structured output) was not valid JSON.
    #[error("Failed to deserialize OpenAI API response: {0}")]
    DeserializeError(#[from] serde_json::Error),
    /// The model refused to answer (explicit refusal or content filter).
    #[error("Refusal from OpenAI API: {0}")]
    Refusal(String),
    /// The model attempted a function/tool call, which this client does not
    /// support.
    #[error("Function calls are not yet supported in kalosm with the OpenAI API")]
    FunctionCallsNotSupported,
}
134
/// Chat history for an [`OpenAICompatibleChatModel`]; serializable so a
/// session can be persisted and later restored.
#[derive(Serialize, Deserialize, Clone)]
pub struct OpenAICompatibleChatSession {
    messages: Vec<crate::ChatMessage>,
}
140
141impl OpenAICompatibleChatSession {
142    fn new() -> Self {
143        Self {
144            messages: Vec::new(),
145        }
146    }
147}
148
149impl ChatSession for OpenAICompatibleChatSession {
150    type Error = serde_json::Error;
151
152    fn write_to(&self, into: &mut Vec<u8>) -> Result<(), Self::Error> {
153        let json = serde_json::to_vec(self)?;
154        into.extend_from_slice(&json);
155        Ok(())
156    }
157
158    fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>
159    where
160        Self: std::marker::Sized,
161    {
162        let json = serde_json::from_slice(bytes)?;
163        Ok(json)
164    }
165
166    fn history(&self) -> Vec<crate::ChatMessage> {
167        self.messages.clone()
168    }
169
170    fn try_clone(&self) -> Result<Self, Self::Error>
171    where
172        Self: std::marker::Sized,
173    {
174        Ok(self.clone())
175    }
176}
177
178impl CreateChatSession for OpenAICompatibleChatModel {
179    type ChatSession = OpenAICompatibleChatSession;
180    type Error = OpenAICompatibleChatModelError;
181
182    fn new_chat_session(&self) -> Result<Self::ChatSession, Self::Error> {
183        Ok(OpenAICompatibleChatSession::new())
184    }
185}
186
/// One streamed chunk of a chat-completions response (only the fields this
/// client reads).
#[derive(Serialize, Deserialize)]
struct OpenAICompatibleChatResponse {
    choices: Vec<OpenAICompatibleChatResponseChoice>,
}
191
/// A single choice within a streamed response chunk.
#[derive(Serialize, Deserialize)]
struct OpenAICompatibleChatResponseChoice {
    // Incremental message content carried by this chunk.
    delta: OpenAICompatibleChatResponseChoiceMessage,
    // Present only on the final chunk for the choice.
    finish_reason: Option<FinishReason>,
}
197
/// Why the API stopped generating, as reported in `finish_reason`.
#[derive(Serialize, Deserialize)]
enum FinishReason {
    /// Output was withheld by the provider's content filter.
    #[serde(rename = "content_filter")]
    ContentFilter,
    /// The model requested a function/tool call (unsupported by this client).
    #[serde(rename = "function_call")]
    FunctionCall,
    /// Generation hit the token limit ("length" on the wire).
    #[serde(rename = "length")]
    MaxTokens,
    /// Natural end of the message or a stop sequence was reached.
    #[serde(rename = "stop")]
    Stop,
}
209
/// Incremental delta payload of a streamed choice.
#[derive(Serialize, Deserialize)]
struct OpenAICompatibleChatResponseChoiceMessage {
    // New text generated since the previous chunk, if any.
    content: Option<String>,
    // Set when the model refuses to answer.
    refusal: Option<String>,
}
215
216impl ChatModel<GenerationParameters> for OpenAICompatibleChatModel {
217    fn add_messages_with_callback<'a>(
218        &'a self,
219        session: &'a mut Self::ChatSession,
220        messages: &[crate::ChatMessage],
221        sampler: GenerationParameters,
222        mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
223    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
224        let myself = &*self.inner;
225        let json = serde_json::json!({
226            "messages": messages,
227            "model": myself.model,
228            "stream": true,
229            "top_p": sampler.top_p,
230            "temperature": sampler.temperature,
231            "frequency_penalty": sampler.repetition_penalty,
232            "max_completion_tokens": if sampler.max_length == u32::MAX { None } else { Some(sampler.max_length) },
233            "stop": sampler.stop_on.clone(),
234        });
235        async move {
236            let api_key = myself.client.resolve_api_key()?;
237            let mut event_source = myself
238                .client
239                .reqwest_client
240                .post(format!("{}/chat/completions", myself.client.base_url()))
241                .header("Content-Type", "application/json")
242                .header("Authorization", format!("Bearer {}", api_key))
243                .json(&json)
244                .eventsource()
245                .unwrap();
246
247            let mut new_message_text = String::new();
248
249            while let Some(event) = event_source.next().await {
250                match event? {
251                    Event::Open => {}
252                    Event::Message(message) => {
253                        let data =
254                            serde_json::from_str::<OpenAICompatibleChatResponse>(&message.data)?;
255                        let first_choice = data
256                            .choices
257                            .into_iter()
258                            .next()
259                            .ok_or(OpenAICompatibleChatModelError::NoMessageChoices)?;
260                        if let Some(content) = first_choice.delta.refusal {
261                            return Err(OpenAICompatibleChatModelError::Refusal(content));
262                        }
263                        if let Some(refusal) = &first_choice.finish_reason {
264                            match refusal {
265                                FinishReason::ContentFilter => {
266                                    return Err(OpenAICompatibleChatModelError::Refusal(
267                                        "ContentFilter".to_string(),
268                                    ))
269                                }
270                                FinishReason::FunctionCall => {
271                                    return Err(
272                                        OpenAICompatibleChatModelError::FunctionCallsNotSupported,
273                                    )
274                                }
275                                _ => return Ok(()),
276                            }
277                        }
278                        if let Some(content) = first_choice.delta.content {
279                            new_message_text += &content;
280                            on_token(content)?;
281                        }
282                    }
283                }
284            }
285
286            let new_message =
287                crate::ChatMessage::new(crate::MessageType::UserMessage, new_message_text);
288
289            session.messages.push(new_message);
290
291            Ok(())
292        }
293    }
294}
295
/// Zero-sized constraint marker: asks the API for output matching the JSON
/// schema of `P` and parses the result into a `P`.
#[derive(Debug, Clone, Copy)]
pub struct SchemaParser<P> {
    // Carries the target type without storing a value of it.
    phantom: std::marker::PhantomData<P>,
}
301
302impl<P> Default for SchemaParser<P> {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
impl<P> SchemaParser<P> {
    /// Const constructor; `SchemaParser` holds no runtime data.
    pub const fn new() -> Self {
        Self {
            phantom: std::marker::PhantomData,
        }
    }
}
316
// A `SchemaParser<P>` constrains generation to produce a `P`.
impl<P> ModelConstraints for SchemaParser<P> {
    type Output = P;
}
320
321impl<T: Schema + DeserializeOwned> CreateDefaultChatConstraintsForType<T>
322    for OpenAICompatibleChatModel
323{
324    type DefaultConstraints = SchemaParser<T>;
325
326    fn create_default_constraints() -> Self::DefaultConstraints {
327        SchemaParser::new()
328    }
329}
330
331impl<P> StructuredChatModel<SchemaParser<P>> for OpenAICompatibleChatModel
332where
333    P: Schema + DeserializeOwned,
334{
335    fn add_message_with_callback_and_constraints<'a>(
336        &'a self,
337        session: &'a mut Self::ChatSession,
338        messages: &[crate::ChatMessage],
339        sampler: GenerationParameters,
340        _: SchemaParser<P>,
341        mut on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
342    ) -> impl Future<Output = Result<P, Self::Error>> + Send + 'a {
343        let schema = P::schema();
344        let mut schema: serde_json::Result<serde_json::Value> =
345            serde_json::from_str(&schema.to_string());
346        fn remove_unsupported_properties(schema: &mut serde_json::Value) {
347            match schema {
348                serde_json::Value::Null => {}
349                serde_json::Value::Bool(_) => {}
350                serde_json::Value::Number(_) => {}
351                serde_json::Value::String(_) => {}
352                serde_json::Value::Array(array) => {
353                    for item in array {
354                        remove_unsupported_properties(item);
355                    }
356                }
357                serde_json::Value::Object(map) => {
358                    map.retain(|key, value| {
359                        const OPEN_AI_UNSUPPORTED_PROPERTIES: [&str; 19] = [
360                            "minLength",
361                            "maxLength",
362                            "pattern",
363                            "format",
364                            "minimum",
365                            "maximum",
366                            "multipleOf",
367                            "patternProperties",
368                            "unevaluatedProperties",
369                            "propertyNames",
370                            "minProperties",
371                            "maxProperties",
372                            "unevaluatedItems",
373                            "contains",
374                            "minContains",
375                            "maxContains",
376                            "minItems",
377                            "maxItems",
378                            "uniqueItems",
379                        ];
380                        if OPEN_AI_UNSUPPORTED_PROPERTIES.contains(&key.as_str()) {
381                            return false;
382                        }
383
384                        remove_unsupported_properties(value);
385                        true
386                    });
387                }
388            }
389        }
390        if let Ok(schema) = &mut schema {
391            remove_unsupported_properties(schema);
392        }
393
394        let myself = &*self.inner;
395        let json = schema.map(|schema| serde_json::json!({
396            "messages": messages,
397            "model": myself.model,
398            "stream": true,
399            "top_p": sampler.top_p,
400            "temperature": sampler.temperature,
401            "frequency_penalty": sampler.repetition_penalty,
402            "max_completion_tokens": if sampler.max_length == u32::MAX { None } else { Some(sampler.max_length) },
403            "stop": sampler.stop_on.clone(),
404            "seed": sampler.seed(),
405            "response_format": {
406                "type": "json_schema",
407                "json_schema": {
408                    "name": "response",
409                    "schema": schema,
410                    "strict": true
411                }
412            }
413        }));
414        async move {
415            let json = json?;
416            let api_key = myself.client.resolve_api_key()?;
417            let mut event_source = myself
418                .client
419                .reqwest_client
420                .post(format!("{}/chat/completions", myself.client.base_url()))
421                .header("Content-Type", "application/json")
422                .header("Authorization", format!("Bearer {}", api_key))
423                .json(&json)
424                .eventsource()
425                .unwrap();
426
427            let mut new_message_text = String::new();
428
429            while let Some(event) = event_source.next().await {
430                match event? {
431                    Event::Open => {}
432                    Event::Message(message) => {
433                        let data =
434                            serde_json::from_str::<OpenAICompatibleChatResponse>(&message.data)?;
435                        let first_choice = data
436                            .choices
437                            .first()
438                            .ok_or(OpenAICompatibleChatModelError::NoMessageChoices)?;
439                        if let Some(content) = &first_choice.delta.refusal {
440                            return Err(OpenAICompatibleChatModelError::Refusal(content.clone()));
441                        }
442                        if let Some(refusal) = &first_choice.finish_reason {
443                            match refusal {
444                                FinishReason::ContentFilter => {
445                                    return Err(OpenAICompatibleChatModelError::Refusal(
446                                        "ContentFilter".to_string(),
447                                    ))
448                                }
449                                FinishReason::FunctionCall => {
450                                    return Err(
451                                        OpenAICompatibleChatModelError::FunctionCallsNotSupported,
452                                    )
453                                }
454                                _ => break,
455                            }
456                        }
457                        if let Some(content) = &first_choice.delta.content {
458                            on_token(content.clone())?;
459                            new_message_text += content;
460                        }
461                    }
462                }
463            }
464
465            let result = serde_json::from_str::<P>(&new_message_text)?;
466
467            let new_message =
468                crate::ChatMessage::new(crate::MessageType::UserMessage, new_message_text);
469
470            session.messages.push(new_message);
471
472            Ok(result)
473        }
474    }
475}
476
477#[cfg(test)]
478mod tests {
479    use std::sync::{Arc, RwLock};
480
481    use serde::Deserialize;
482
483    use super::{
484        ChatModel, CreateChatSession, GenerationParameters, OpenAICompatibleChatModelBuilder,
485        SchemaParser, StructuredChatModel,
486    };
487
    /// Integration test: streams a reply from `gpt-4o-mini` and asserts that
    /// some text was produced.
    ///
    /// NOTE(review): this hits the live API — it requires network access and
    /// an API key resolvable by the default client (presumably from the
    /// environment; confirm against `OpenAICompatibleClient::resolve_api_key`).
    #[tokio::test]
    async fn test_gpt_4o_mini() {
        let model = OpenAICompatibleChatModelBuilder::new()
            .with_gpt_4o_mini()
            .build();

        let mut session = model.new_chat_session().unwrap();

        let messages = vec![crate::ChatMessage::new(
            crate::MessageType::UserMessage,
            "Hello, world!".to_string(),
        )];
        // Accumulates the streamed tokens; shared because the callback must
        // be 'static.
        let all_text = Arc::new(RwLock::new(String::new()));
        model
            .add_messages_with_callback(
                &mut session,
                &messages,
                // Fixed seed for more reproducible output.
                GenerationParameters::default().with_seed(1234),
                {
                    let all_text = all_text.clone();
                    move |token| {
                        let mut all_text = all_text.write().unwrap();
                        all_text.push_str(&token);
                        // Echo tokens as they stream in.
                        print!("{token}");
                        std::io::Write::flush(&mut std::io::stdout()).unwrap();
                        Ok(())
                    }
                },
            )
            .await
            .unwrap();

        let all_text = all_text.read().unwrap();
        println!("{all_text}");

        assert!(!all_text.is_empty());
    }
525
    /// Integration test: requests schema-constrained output from
    /// `gpt-4o-mini` and checks that the parsed struct is non-empty.
    ///
    /// NOTE(review): this hits the live API — it requires network access and
    /// an API key resolvable by the default client (presumably from the
    /// environment; confirm against `OpenAICompatibleClient::resolve_api_key`).
    #[tokio::test]
    async fn test_gpt_4o_mini_constrained() {
        let model = OpenAICompatibleChatModelBuilder::new()
            .with_gpt_4o_mini()
            .build();

        let mut session = model.new_chat_session().unwrap();

        let messages = vec![crate::ChatMessage::new(
            crate::MessageType::UserMessage,
            "Give me a list of 5 primes.".to_string(),
        )];
        // Accumulates the streamed tokens; shared because the callback must
        // be 'static.
        let all_text = Arc::new(RwLock::new(String::new()));

        // Target type: the response must be a JSON object with a `primes`
        // array, enforced via the generated schema.
        #[derive(Debug, Clone, kalosm_sample::Parse, kalosm_sample::Schema, Deserialize)]
        struct Constraints {
            primes: Vec<u8>,
        }

        let response: Constraints = model
            .add_message_with_callback_and_constraints(
                &mut session,
                &messages,
                GenerationParameters::default(),
                SchemaParser::new(),
                {
                    let all_text = all_text.clone();
                    move |token| {
                        let mut all_text = all_text.write().unwrap();
                        all_text.push_str(&token);
                        // Echo tokens as they stream in.
                        print!("{token}");
                        std::io::Write::flush(&mut std::io::stdout()).unwrap();
                        Ok(())
                    }
                },
            )
            .await
            .unwrap();
        println!("{response:?}");

        let all_text = all_text.read().unwrap();
        println!("{all_text}");

        assert!(!all_text.is_empty());

        assert!(!response.primes.is_empty());
    }
573}