rig_cat/provider/openai/
mod.rs1use comp_cat_rs::effect::io::Io;
4use comp_cat_rs::effect::stream::Stream;
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8use crate::model::{
9 CompletionModel, CompletionRequest, CompletionResponse, StreamChunk,
10};
11use crate::embedding::{Embedding, EmbeddingModel, EmbeddingRequest};
12
/// Secret bearer token used to authenticate against the OpenAI API.
///
/// The raw key is only reachable through a private accessor (used to build the
/// `Authorization: Bearer …` header), and its `Debug` output is redacted so the
/// secret cannot leak through `{:?}` formatting or logs.
#[derive(Clone)]
pub struct ApiKey(String);

impl ApiKey {
    /// Wraps a raw API key string.
    #[must_use]
    pub fn new(key: String) -> Self {
        Self(key)
    }

    /// Borrows the raw key; private so only this module can read the secret.
    fn as_str(&self) -> &str {
        &self.0
    }
}

// Hand-written instead of derived: a derived `Debug` would print the secret.
impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ApiKey(<redacted>)")
    }
}
23
/// Identifier of the OpenAI model to call (for example `"gpt-4o"`).
///
/// The wrapped string is sent verbatim as the request's `model` field.
#[derive(Clone)]
pub struct ModelName(String);

impl ModelName {
    /// Wraps `name` as a model identifier without validation.
    #[must_use]
    pub fn new(name: String) -> Self {
        ModelName(name)
    }

    /// Borrows the identifier as a string slice.
    fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
34
35pub struct OpenAiCompletion {
37 api_key: ApiKey,
38 model: ModelName,
39}
40
41impl OpenAiCompletion {
42 #[must_use]
43 pub fn new(api_key: ApiKey, model: ModelName) -> Self {
44 Self { api_key, model }
45 }
46}
47
48pub struct OpenAiEmbedding {
50 api_key: ApiKey,
51 model: ModelName,
52}
53
54impl OpenAiEmbedding {
55 #[must_use]
56 pub fn new(api_key: ApiKey, model: ModelName) -> Self {
57 Self { api_key, model }
58 }
59}
60
61#[derive(Serialize)]
64struct ChatRequest {
65 model: String,
66 messages: Vec<ChatMessage>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 temperature: Option<f64>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 max_tokens: Option<u32>,
71}
72
/// One message of a [`ChatRequest`] conversation, in OpenAI wire format.
#[derive(Serialize)]
struct ChatMessage {
    /// Role string; `complete` only ever produces `"system"`, `"user"` or `"assistant"`.
    role: String,
    /// Plain-text message body.
    content: String,
}
78
/// The subset of the chat-completions response body this module reads.
///
/// Other JSON fields are ignored (no `deny_unknown_fields`).
#[derive(Deserialize)]
struct ChatResponse {
    /// Generated alternatives; `complete` only looks at the first one.
    choices: Vec<ChatChoice>,
    /// Model name as reported back by the API; forwarded into `CompletionResponse`.
    model: String,
}
84
/// One entry of [`ChatResponse::choices`].
#[derive(Deserialize)]
struct ChatChoice {
    /// The assistant message generated for this choice.
    message: ChatChoiceMessage,
}
89
/// The message payload inside a [`ChatChoice`].
#[derive(Deserialize)]
struct ChatChoiceMessage {
    /// Generated text. `None` when the JSON `content` is `null` or missing
    /// (NOTE(review): presumably e.g. tool-call-only replies — confirm);
    /// `complete` turns that into an empty string.
    content: Option<String>,
}
94
/// JSON body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct EmbedRequest {
    /// Embedding model identifier, copied from the client's `ModelName`.
    model: String,
    /// Texts to embed; one embedding is requested per entry.
    input: Vec<String>,
}
100
/// The subset of the `/v1/embeddings` response body this module reads.
#[derive(Deserialize)]
struct EmbedResponse {
    /// One entry per embedded input. `embed` returns them in the order listed here
    /// and does not reorder them (NOTE(review): assumes the API lists them in input
    /// order — confirm against its per-item `index` field, which is not decoded).
    data: Vec<EmbedData>,
}
105
/// A single entry of [`EmbedResponse::data`].
#[derive(Deserialize)]
struct EmbedData {
    /// The embedding vector, decoded as `f64` components.
    embedding: Vec<f64>,
}
110
111impl CompletionModel for OpenAiCompletion {
114 fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse> {
115 let api_key = self.api_key.clone();
116 let model_name = self.model.clone();
117 Io::suspend(move || {
118 let messages: Vec<ChatMessage> = request.messages().iter().map(|m| {
119 ChatMessage {
120 role: match m.role() {
121 crate::model::Role::System => "system".to_owned(),
122 crate::model::Role::User => "user".to_owned(),
123 crate::model::Role::Assistant => "assistant".to_owned(),
124 },
125 content: m.content().to_owned(),
126 }
127 }).collect();
128
129 let body = ChatRequest {
130 model: model_name.as_str().to_owned(),
131 messages,
132 temperature: request.temperature(),
133 max_tokens: request.max_tokens(),
134 };
135
136 let resp: ChatResponse = ureq::post("https://api.openai.com/v1/chat/completions")
137 .header("Authorization", &format!("Bearer {}", api_key.as_str()))
138 .header("Content-Type", "application/json")
139 .send_json(&body)
140 .map_err(Error::from)?
141 .into_body()
142 .read_json()
143 .map_err(Error::from)?;
144
145 let content = resp.choices.first()
146 .and_then(|c| c.message.content.clone())
147 .unwrap_or_default();
148
149 Ok(CompletionResponse::new(content, resp.model))
150 })
151 }
152
153 fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
154 Stream::empty()
156 }
157}
158
159impl EmbeddingModel for OpenAiEmbedding {
160 fn embed(&self, request: EmbeddingRequest) -> Io<Error, Vec<Embedding>> {
161 let api_key = self.api_key.clone();
162 let model_name = self.model.clone();
163 Io::suspend(move || {
164 let body = EmbedRequest {
165 model: model_name.as_str().to_owned(),
166 input: request.texts().to_vec(),
167 };
168
169 let resp: EmbedResponse = ureq::post("https://api.openai.com/v1/embeddings")
170 .header("Authorization", &format!("Bearer {}", api_key.as_str()))
171 .header("Content-Type", "application/json")
172 .send_json(&body)
173 .map_err(Error::from)?
174 .into_body()
175 .read_json()
176 .map_err(Error::from)?;
177
178 Ok(resp.data.into_iter()
179 .map(|d| Embedding::new(d.embedding))
180 .collect())
181 })
182 }
183}