Skip to main content

rig_cat/provider/openai/
mod.rs

1//! `OpenAI` provider: completion and embedding models.
2
3use comp_cat_rs::effect::io::Io;
4use comp_cat_rs::effect::stream::Stream;
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8use crate::model::{
9    CompletionModel, CompletionRequest, CompletionResponse, StreamChunk,
10};
11use crate::embedding::{Embedding, EmbeddingModel, EmbeddingRequest};
12
/// Newtype for the `OpenAI` API key.
///
/// The key is sent as a `Bearer` token in the `Authorization` header of every
/// request. Its `Debug` output is redacted so the secret is never printed by
/// `{:?}` formatting (logs, panics, `dbg!`).
#[derive(Clone)]
pub struct ApiKey(String);

impl ApiKey {
    /// Wraps a raw API key string.
    #[must_use]
    pub fn new(key: String) -> Self {
        Self(key)
    }

    /// Borrows the raw key, for building the `Authorization` header.
    fn as_str(&self) -> &str {
        &self.0
    }
}

// Hand-written instead of derived: a derived `Debug` would print the secret.
impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ApiKey(<redacted>)")
    }
}
23
/// Newtype for a model name.
///
/// The wrapped string is sent verbatim as the `model` field of chat and
/// embedding requests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelName(String);

impl ModelName {
    /// Wraps a model identifier string.
    #[must_use]
    pub fn new(name: String) -> Self {
        Self(name)
    }

    /// Borrows the model identifier.
    fn as_str(&self) -> &str {
        &self.0
    }
}
34
35/// `OpenAI` completion model.
36pub struct OpenAiCompletion {
37    api_key: ApiKey,
38    model: ModelName,
39}
40
41impl OpenAiCompletion {
42    #[must_use]
43    pub fn new(api_key: ApiKey, model: ModelName) -> Self {
44        Self { api_key, model }
45    }
46}
47
48/// `OpenAI` embedding model.
49pub struct OpenAiEmbedding {
50    api_key: ApiKey,
51    model: ModelName,
52}
53
54impl OpenAiEmbedding {
55    #[must_use]
56    pub fn new(api_key: ApiKey, model: ModelName) -> Self {
57        Self { api_key, model }
58    }
59}
60
// --- Request/response JSON shapes ---
//
// Field names below are the JSON wire names: renaming a field changes what is
// sent or what is read back.

/// JSON body posted to the chat-completions endpoint.
#[derive(Serialize)]
struct ChatRequest {
    /// Model identifier, copied from the client's `ModelName`.
    model: String,
    /// Conversation history, in the order given by the `CompletionRequest`.
    messages: Vec<ChatMessage>,
    /// Forwarded from `CompletionRequest::temperature`; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Forwarded from `CompletionRequest::max_tokens`; omitted from the JSON when `None`.
    // NOTE(review): newer OpenAI models may expect `max_completion_tokens`
    // instead of `max_tokens` — confirm against the API reference for the
    // models this client targets.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// One entry of [`ChatRequest::messages`].
#[derive(Serialize)]
struct ChatMessage {
    /// One of `"system"`, `"user"`, or `"assistant"`, as mapped in `complete`.
    role: String,
    /// Message text.
    content: String,
}

/// The part of the chat-completions response that this module reads.
///
/// There is no `deny_unknown_fields`, so serde ignores any other fields in the reply.
#[derive(Deserialize)]
struct ChatResponse {
    /// Candidate completions; `complete` uses only the first.
    choices: Vec<ChatChoice>,
    /// Model name reported by the server, passed through to `CompletionResponse`.
    model: String,
}

/// One element of [`ChatResponse::choices`].
#[derive(Deserialize)]
struct ChatChoice {
    message: ChatChoiceMessage,
}

/// The message carried by a [`ChatChoice`].
#[derive(Deserialize)]
struct ChatChoiceMessage {
    /// Reply text. `Option` so that a JSON `null` deserializes instead of failing.
    content: Option<String>,
}
94
/// JSON body posted to the embeddings endpoint.
#[derive(Serialize)]
struct EmbedRequest {
    /// Model identifier, copied from the client's `ModelName`.
    model: String,
    /// Texts to embed, copied from `EmbeddingRequest::texts`.
    input: Vec<String>,
}

/// The part of the embeddings response that this module reads.
#[derive(Deserialize)]
struct EmbedResponse {
    /// Embedding entries; `embed` converts them in array order.
    data: Vec<EmbedData>,
}

/// One element of [`EmbedResponse::data`].
// NOTE(review): only `embedding` is deserialized, so callers rely on the array
// order of `data` matching the input order. If the API provides a per-item
// index, consider reading it and sorting by it — confirm.
#[derive(Deserialize)]
struct EmbedData {
    /// The embedding vector.
    embedding: Vec<f64>,
}
110
111// --- Trait impls ---
112
113impl CompletionModel for OpenAiCompletion {
114    fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse> {
115        let api_key = self.api_key.clone();
116        let model_name = self.model.clone();
117        Io::suspend(move || {
118            let messages: Vec<ChatMessage> = request.messages().iter().map(|m| {
119                ChatMessage {
120                    role: match m.role() {
121                        crate::model::Role::System => "system".to_owned(),
122                        crate::model::Role::User => "user".to_owned(),
123                        crate::model::Role::Assistant => "assistant".to_owned(),
124                    },
125                    content: m.content().to_owned(),
126                }
127            }).collect();
128
129            let body = ChatRequest {
130                model: model_name.as_str().to_owned(),
131                messages,
132                temperature: request.temperature(),
133                max_tokens: request.max_tokens(),
134            };
135
136            let resp: ChatResponse = ureq::post("https://api.openai.com/v1/chat/completions")
137                .header("Authorization", &format!("Bearer {}", api_key.as_str()))
138                .header("Content-Type", "application/json")
139                .send_json(&body)
140                .map_err(Error::from)?
141                .into_body()
142                .read_json()
143                .map_err(Error::from)?;
144
145            let content = resp.choices.first()
146                .and_then(|c| c.message.content.clone())
147                .unwrap_or_default();
148
149            Ok(CompletionResponse::new(content, resp.model))
150        })
151    }
152
153    fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
154        // TODO: implement SSE streaming
155        Stream::empty()
156    }
157}
158
159impl EmbeddingModel for OpenAiEmbedding {
160    fn embed(&self, request: EmbeddingRequest) -> Io<Error, Vec<Embedding>> {
161        let api_key = self.api_key.clone();
162        let model_name = self.model.clone();
163        Io::suspend(move || {
164            let body = EmbedRequest {
165                model: model_name.as_str().to_owned(),
166                input: request.texts().to_vec(),
167            };
168
169            let resp: EmbedResponse = ureq::post("https://api.openai.com/v1/embeddings")
170                .header("Authorization", &format!("Bearer {}", api_key.as_str()))
171                .header("Content-Type", "application/json")
172                .send_json(&body)
173                .map_err(Error::from)?
174                .into_body()
175                .read_json()
176                .map_err(Error::from)?;
177
178            Ok(resp.data.into_iter()
179                .map(|d| Embedding::new(d.embedding))
180                .collect())
181        })
182    }
183}