kernelx_core/providers/remote/openai/
capabilities.rs1use async_openai::types::{
2 CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, CreateImageRequestArgs, Image,
3 ImageModel, ResponseFormat, ResponseFormatJsonSchema,
4};
5use async_trait::async_trait;
6use serde_json::Value;
7
8use crate::{
9 capabilities::{Chat, ChatMessage, Complete, ImageGen, Structured},
10 models::ModelConfig,
11 Error, Result,
12};
13
14use super::{Embed, IntoOpenAIChatMessage, OpenAI};
15
16#[async_trait]
17impl Complete for OpenAI {
18 async fn complete_impl(
19 &self,
20 model: &str,
21 prompt: &str,
22 config: &ModelConfig,
23 ) -> Result<String> {
24 let messages = vec![
25 ChatMessage::system(
26 config
27 .system_prompt
28 .as_deref()
29 .unwrap_or("You are a helpful assistant."),
30 ),
31 ChatMessage::user(prompt),
32 ]
33 .into_openai_messages()?;
34
35 let request = CreateChatCompletionRequestArgs::default()
36 .model(model)
37 .messages(messages)
38 .temperature(config.temperature.unwrap_or(0.5))
39 .max_tokens(config.max_tokens.unwrap_or(256))
40 .build()?;
41
42 let res = self.client.chat().create(request).await?;
43 Ok(res
44 .choices
45 .first()
46 .and_then(|choice| choice.message.content.clone())
47 .unwrap_or_default())
48 }
49}
50
51#[async_trait]
52impl Structured for OpenAI {
53 async fn structured_impl(
54 &self,
55 model: &str,
56 prompt: &str,
57 schema: &Value,
58 config: &ModelConfig,
59 ) -> Result<Value> {
60 let messages = vec![
61 ChatMessage::system(
62 config
63 .system_prompt
64 .as_deref()
65 .unwrap_or("You are a helpful assistant that outputs valid JSON."),
66 ),
67 ChatMessage::user(prompt),
68 ]
69 .into_openai_messages()?;
70
71 let response_format = ResponseFormat::JsonSchema {
72 json_schema: ResponseFormatJsonSchema {
73 name: "Response".to_string(),
74 description: None,
75 schema: Some(schema.clone()),
76 strict: Some(true),
77 },
78 };
79
80 let request = CreateChatCompletionRequestArgs::default()
81 .model(model)
82 .messages(messages)
83 .temperature(config.temperature.unwrap_or(0.5))
84 .max_tokens(config.max_tokens.unwrap_or(256))
85 .response_format(response_format)
86 .build()?;
87
88 let res = self.client.chat().create(request).await?.choices[0]
89 .message
90 .content
91 .clone()
92 .ok_or_else(|| Error::NoContent)?;
93
94 serde_json::from_str(&res).map_err(Error::Serialization)
95 }
96}
97
98#[async_trait]
99impl Chat for OpenAI {
100 async fn chat_impl(
101 &self,
102 model: &str,
103 messages: Vec<ChatMessage>,
104 config: &ModelConfig,
105 ) -> Result<String> {
106 let mut openai_messages = vec![ChatMessage::system(
108 config.system_prompt.as_deref().unwrap(),
109 )];
110
111 openai_messages.extend(messages);
113
114 let openai_messages = openai_messages.into_openai_messages()?;
116
117 let request = CreateChatCompletionRequestArgs::default()
118 .model(model)
119 .messages(openai_messages)
120 .temperature(config.temperature.unwrap_or(0.7))
121 .max_tokens(config.max_tokens.unwrap_or(2048))
122 .build()?;
123
124 let response = self.client.chat().create(request).await?;
125 Ok(response
126 .choices
127 .first()
128 .and_then(|choice| choice.message.content.clone())
129 .unwrap_or_default())
130 }
131}
132
133#[async_trait]
134impl Embed for OpenAI {
135 async fn embed_impl(
136 &self,
137 _model: &str,
138 texts: Vec<String>,
139 config: &ModelConfig,
140 ) -> Result<Vec<Vec<f32>>> {
141 let request = CreateEmbeddingRequestArgs::default()
142 .model(config.model_id.as_ref().unwrap().as_str())
143 .input(texts)
144 .build()?;
145
146 let response = self.client.embeddings().create(request).await?;
147 Ok(response.data.into_iter().map(|d| d.embedding).collect())
148 }
149}
150
151#[async_trait]
152impl ImageGen for OpenAI {
153 async fn image_gen_impl(
154 &self,
155 _model: &str,
156 prompt: &str,
157 _config: &ModelConfig,
158 ) -> Result<String> {
159 let model = ImageModel::DallE3;
160
161 let request = CreateImageRequestArgs::default()
162 .model(model)
163 .prompt(prompt)
164 .build()?;
165
166 let response = self.client.images().create(request).await?;
167 match response.data.first() {
168 Some(image) => match &**image {
169 Image::Url { url, .. } => Ok(url.clone()),
170 Image::B64Json { .. } => Err(Error::UnsupportedCapability(
171 "B64Json response format not supported.".into(),
172 )),
173 },
174 None => Err(Error::NoContent),
175 }
176 }
177}