1use crate::api::ChatProvider;
2use crate::types::AiLibError;
3use crate::types::{ChatCompletionRequest, ChatCompletionResponse, Message, Role};
4use async_trait::async_trait;
5use futures::Stream;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8
9pub struct AI21Adapter {
14 client: Client,
15 api_key: String,
16}
17
18impl AI21Adapter {
19 pub fn new() -> Result<Self, AiLibError> {
20 let api_key = std::env::var("AI21_API_KEY").map_err(|_| {
21 AiLibError::ConfigurationError("AI21_API_KEY environment variable not set".to_string())
22 })?;
23
24 Ok(Self {
25 client: Client::new(),
26 api_key,
27 })
28 }
29
30 pub async fn chat_completion(
31 &self,
32 request: ChatCompletionRequest,
33 ) -> Result<ChatCompletionResponse, AiLibError> {
34 let ai21_request = self.convert_request(&request)?;
35
36 let response = self
37 .client
38 .post("https://api.ai21.com/studio/v1/chat/completions")
39 .header("Authorization", format!("Bearer {}", self.api_key))
40 .header("Content-Type", "application/json")
41 .json(&ai21_request)
42 .send()
43 .await
44 .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
45
46 if !response.status().is_success() {
47 let status = response.status();
48 let error_text = response
49 .text()
50 .await
51 .unwrap_or_else(|_| "Unknown error".to_string());
52 return Err(AiLibError::ProviderError(format!(
53 "AI21 API error {}: {}",
54 status, error_text
55 )));
56 }
57
58 let ai21_response: AI21Response = response.json().await.map_err(|e| {
59 AiLibError::DeserializationError(format!("Failed to parse AI21 response: {}", e))
60 })?;
61
62 self.convert_response(ai21_response)
63 }
64
65 pub async fn chat_completion_stream(
66 &self,
67 request: ChatCompletionRequest,
68 ) -> Result<
69 Box<
70 dyn futures::Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>>
71 + Send
72 + Unpin,
73 >,
74 AiLibError,
75 > {
76 let mut ai21_request = self.convert_request(&request)?;
77 ai21_request.stream = Some(true);
78
79 let response = self
80 .client
81 .post("https://api.ai21.com/studio/v1/chat/completions")
82 .header("Authorization", format!("Bearer {}", self.api_key))
83 .header("Content-Type", "application/json")
84 .json(&ai21_request)
85 .send()
86 .await
87 .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
88
89 if !response.status().is_success() {
90 let status = response.status();
91 let error_text = response
92 .text()
93 .await
94 .unwrap_or_else(|_| "Unknown error".to_string());
95 return Err(AiLibError::ProviderError(format!(
96 "AI21 API error {}: {}",
97 status, error_text
98 )));
99 }
100
101 let response = self.chat_completion(request.clone()).await?;
103
104 let chunk = crate::api::ChatCompletionChunk {
106 id: response.id.clone(),
107 object: "chat.completion.chunk".to_string(),
108 created: response.created,
109 model: response.model.clone(),
110 choices: response
111 .choices
112 .into_iter()
113 .map(|choice| crate::api::ChoiceDelta {
114 index: choice.index,
115 delta: crate::api::MessageDelta {
116 role: Some(choice.message.role),
117 content: Some(match &choice.message.content {
118 crate::Content::Text(text) => text.clone(),
119 _ => "".to_string(),
120 }),
121 },
122 finish_reason: choice.finish_reason,
123 })
124 .collect(),
125 };
126
127 let stream = futures::stream::once(async move { Ok(chunk) });
128 Ok(Box::new(Box::pin(stream)))
129 }
130
131 fn convert_request(&self, request: &ChatCompletionRequest) -> Result<AI21Request, AiLibError> {
132 let messages = request
134 .messages
135 .iter()
136 .map(|msg| AI21Message {
137 role: match msg.role {
138 Role::System => "system".to_string(),
139 Role::User => "user".to_string(),
140 Role::Assistant => "assistant".to_string(),
141 },
142 content: match &msg.content {
143 crate::Content::Text(text) => text.clone(),
144 _ => "Unsupported content type".to_string(),
145 },
146 })
147 .collect();
148
149 Ok(AI21Request {
150 model: request.model.clone(),
151 messages,
152 max_tokens: request.max_tokens,
153 temperature: request.temperature,
154 top_p: request.top_p,
155 stream: Some(false),
156 extensions: request.extensions.clone(),
157 })
158 }
159
160 fn convert_response(
161 &self,
162 response: AI21Response,
163 ) -> Result<ChatCompletionResponse, AiLibError> {
164 let choice = response.choices.first().ok_or_else(|| {
165 AiLibError::InvalidModelResponse("No choices in AI21 response".to_string())
166 })?;
167
168 let message = Message {
169 role: match choice.message.role.as_str() {
170 "assistant" => Role::Assistant,
171 "user" => Role::User,
172 "system" => Role::System,
173 _ => Role::Assistant,
174 },
175 content: crate::Content::Text(choice.message.content.clone().unwrap_or_default()),
176 function_call: None,
177 };
178
179 Ok(ChatCompletionResponse {
180 id: response.id,
181 object: "chat.completion".to_string(),
182 created: response.created,
183 model: response.model,
184 choices: vec![crate::types::Choice {
185 index: 0,
186 message,
187 finish_reason: choice.finish_reason.clone(),
188 }],
189 usage: response
190 .usage
191 .map(|u| crate::types::Usage {
192 prompt_tokens: u.prompt_tokens,
193 completion_tokens: u.completion_tokens,
194 total_tokens: u.total_tokens,
195 })
196 .unwrap_or_else(|| crate::types::Usage {
197 prompt_tokens: 0,
198 completion_tokens: 0,
199 total_tokens: 0,
200 }),
201 usage_status: crate::types::response::UsageStatus::Finalized,
202 })
203 }
204}
205
206#[async_trait]
207impl ChatProvider for AI21Adapter {
208 fn name(&self) -> &str {
209 "AI21"
210 }
211
212 async fn chat(
213 &self,
214 request: ChatCompletionRequest,
215 ) -> Result<ChatCompletionResponse, AiLibError> {
216 self.chat_completion(request).await
217 }
218
219 async fn stream(
220 &self,
221 request: ChatCompletionRequest,
222 ) -> Result<
223 Box<dyn Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>,
224 AiLibError,
225 > {
226 self.chat_completion_stream(request).await
227 }
228
229 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
230 Ok(vec![
232 "j2-ultra".to_string(),
233 "j2-mid".to_string(),
234 "j2-light".to_string(),
235 ])
236 }
237
238 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
239 Ok(crate::api::ModelInfo {
240 id: model_id.to_string(),
241 object: "model".to_string(),
242 created: 0,
243 owned_by: "ai21".to_string(),
244 permission: vec![],
245 })
246 }
247}
248
249#[derive(Serialize)]
250struct AI21Request {
251 model: String,
252 messages: Vec<AI21Message>,
253 max_tokens: Option<u32>,
254 temperature: Option<f32>,
255 top_p: Option<f32>,
256 stream: Option<bool>,
257 #[serde(flatten)]
258 extensions: Option<serde_json::Map<String, serde_json::Value>>,
259}
260
/// One chat message inside an `AI21Request` payload.
#[derive(Serialize)]
struct AI21Message {
    /// `"system"`, `"user"` or `"assistant"`, as produced by `convert_request`.
    role: String,
    /// Message text (`convert_request` substitutes a placeholder for non-text content).
    content: String,
}
266
267#[derive(Deserialize)]
268struct AI21Response {
269 id: String,
270 #[allow(dead_code)]
271 object: String,
272 created: u64,
273 model: String,
274 choices: Vec<AI21Choice>,
275 usage: Option<AI21Usage>,
276}
277
/// One entry of an AI21 response's `choices` array.
#[derive(Deserialize)]
struct AI21Choice {
    /// The generated message.
    message: AI21MessageResponse,
    /// Why generation stopped; `None` if AI21 sends `null` or omits it.
    finish_reason: Option<String>,
}
283
/// Message payload inside an AI21 response choice.
#[derive(Deserialize)]
struct AI21MessageResponse {
    /// Role string; `convert_response` maps unrecognised values to `Role::Assistant`.
    role: String,
    /// Generated text; `convert_response` uses an empty string when this is `None`.
    content: Option<String>,
}
289
/// Token accounting returned by AI21; all three counts are required when `usage` is present.
#[derive(Deserialize)]
struct AI21Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}