kernelx_core/providers/remote/openai/capabilities.rs

use async_openai::types::{
    CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, CreateImageRequestArgs, Image,
    ImageModel, ResponseFormat, ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use serde_json::Value;

use crate::{
    capabilities::{Chat, ChatMessage, Complete, ImageGen, Structured},
    models::ModelConfig,
    Error, Result,
};

use super::{Embed, IntoOpenAIChatMessage, OpenAI};

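// `Complete`: one-shot text completion. The prompt is wrapped in a minimal
// system + user exchange and the first choice's content is returned,
// defaulting to an empty string when the response carries no content.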
#[async_trait]
impl Complete for OpenAI {
    async fn complete_impl(
        &self,
        model: &str,
        prompt: &str,
        config: &ModelConfig,
    ) -> Result<String> {
        let messages = vec![
            ChatMessage::system(
                config
                    .system_prompt
                    .as_deref()
                    .unwrap_or("You are a helpful assistant."),
            ),
            ChatMessage::user(prompt),
        ]
        .into_openai_messages()?;

        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(messages)
            .temperature(config.temperature.unwrap_or(0.5))
            .max_tokens(config.max_tokens.unwrap_or(256))
            .build()?;

        let res = self.client.chat().create(request).await?;
        Ok(res
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default())
    }
}

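// `Structured`: schema-constrained JSON output. The caller's JSON Schema is
// attached as a strict `json_schema` response format, and the reply is parsed
// back into a `serde_json::Value`. Note that OpenAI's strict mode requires
// every property to be listed in `required` and `additionalProperties: false`.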
#[async_trait]
impl Structured for OpenAI {
    async fn structured_impl(
        &self,
        model: &str,
        prompt: &str,
        schema: &Value,
        config: &ModelConfig,
    ) -> Result<Value> {
        let messages = vec![
            ChatMessage::system(
                config
                    .system_prompt
                    .as_deref()
                    .unwrap_or("You are a helpful assistant that outputs valid JSON."),
            ),
            ChatMessage::user(prompt),
        ]
        .into_openai_messages()?;

        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                name: "Response".to_string(),
                description: None,
                schema: Some(schema.clone()),
                strict: Some(true),
            },
        };

        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(messages)
            .temperature(config.temperature.unwrap_or(0.5))
            .max_tokens(config.max_tokens.unwrap_or(256))
            .response_format(response_format)
            .build()?;

        // Avoid indexing into `choices`, which panics on an empty response.
        let res = self.client.chat().create(request).await?;
        let content = res
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or(Error::NoContent)?;

        serde_json::from_str(&content).map_err(Error::Serialization)
    }
}

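// `Chat`: multi-turn conversation. The config's system prompt is prepended to
// the caller-supplied history, then the whole transcript is sent in one request.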
#[async_trait]
impl Chat for OpenAI {
    async fn chat_impl(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
        config: &ModelConfig,
    ) -> Result<String> {
        // Prepend the system message from the config, falling back to a
        // default instead of panicking when no system prompt is set.
        let mut openai_messages = vec![ChatMessage::system(
            config
                .system_prompt
                .as_deref()
                .unwrap_or("You are a helpful assistant."),
        )];

        // Extend the vector with the provided messages
        openai_messages.extend(messages);

        // Convert the combined messages to OpenAI format
        let openai_messages = openai_messages.into_openai_messages()?;

        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(openai_messages)
            .temperature(config.temperature.unwrap_or(0.7))
            .max_tokens(config.max_tokens.unwrap_or(2048))
            .build()?;

        let response = self.client.chat().create(request).await?;
        Ok(response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default())
    }
}

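// `Embed`: batch embeddings. All input texts are embedded in a single request
// and the vectors are collected in the order the API returns them.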
#[async_trait]
impl Embed for OpenAI {
    async fn embed_impl(
        &self,
        model: &str,
        texts: Vec<String>,
        config: &ModelConfig,
    ) -> Result<Vec<Vec<f32>>> {
        // Prefer an explicit model_id from the config; fall back to the
        // requested model rather than panicking when it is unset.
        let request = CreateEmbeddingRequestArgs::default()
            .model(config.model_id.as_deref().unwrap_or(model))
            .input(texts)
            .build()?;

        let response = self.client.embeddings().create(request).await?;
        Ok(response.data.into_iter().map(|d| d.embedding).collect())
    }
}

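// `ImageGen`: image generation, currently pinned to DALL·E 3. Only URL
// responses are handled; base64 payloads are rejected as unsupported.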
#[async_trait]
impl ImageGen for OpenAI {
    async fn image_gen_impl(
        &self,
        _model: &str,
        prompt: &str,
        _config: &ModelConfig,
    ) -> Result<String> {
        let model = ImageModel::DallE3;

        let request = CreateImageRequestArgs::default()
            .model(model)
            .prompt(prompt)
            .build()?;

        let response = self.client.images().create(request).await?;
        match response.data.first() {
            Some(image) => match &**image {
                Image::Url { url, .. } => Ok(url.clone()),
                Image::B64Json { .. } => Err(Error::UnsupportedCapability(
                    "B64Json response format not supported.".into(),
                )),
            },
            None => Err(Error::NoContent),
        }
    }
}
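
// A minimal usage sketch, not part of the original file: it assumes a
// hypothetical `OpenAI::new()` constructor, a `Default` impl on `ModelConfig`,
// and a `tokio` dev-dependency; only the `*_impl` trait methods above are
// taken from this module.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    #[ignore] // needs OPENAI_API_KEY and network access
    async fn complete_then_structured() -> Result<()> {
        let provider = OpenAI::new(); // hypothetical constructor
        let config = ModelConfig::default(); // assumes a Default impl

        // Plain completion via the `Complete` capability.
        let answer = provider
            .complete_impl("gpt-4o-mini", "Name one prime number.", &config)
            .await?;
        assert!(!answer.is_empty());

        // Structured output: strict mode wants `required` to cover every
        // property and `additionalProperties: false`.
        let schema = json!({
            "type": "object",
            "properties": { "prime": { "type": "integer" } },
            "required": ["prime"],
            "additionalProperties": false
        });
        let value = provider
            .structured_impl("gpt-4o-mini", "Pick a prime.", &schema, &config)
            .await?;
        assert!(value.get("prime").is_some());
        Ok(())
    }
}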