1use crate::api::ChatApi;
2use crate::types::{
3 ChatCompletionRequest, ChatCompletionResponse, Message, Role,
4};
5use crate::types::AiLibError;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use futures::Stream;
9use async_trait::async_trait;
10
11pub struct AI21Adapter {
16 client: Client,
17 api_key: String,
18}
19
20impl AI21Adapter {
21 pub fn new() -> Result<Self, AiLibError> {
22 let api_key = std::env::var("AI21_API_KEY")
23 .map_err(|_| AiLibError::ConfigurationError(
24 "AI21_API_KEY environment variable not set".to_string()
25 ))?;
26
27 Ok(Self {
28 client: Client::new(),
29 api_key,
30 })
31 }
32
33 pub async fn chat_completion(
34 &self,
35 request: ChatCompletionRequest,
36 ) -> Result<ChatCompletionResponse, AiLibError> {
37 let ai21_request = self.convert_request(&request)?;
38
39 let response = self
40 .client
41 .post("https://api.ai21.com/studio/v1/chat/completions")
42 .header("Authorization", format!("Bearer {}", self.api_key))
43 .header("Content-Type", "application/json")
44 .json(&ai21_request)
45 .send()
46 .await
47 .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
48
49 if !response.status().is_success() {
50 let status = response.status();
51 let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
52 return Err(AiLibError::ProviderError(format!(
53 "AI21 API error {}: {}",
54 status, error_text
55 )));
56 }
57
58 let ai21_response: AI21Response = response
59 .json()
60 .await
61 .map_err(|e| AiLibError::DeserializationError(format!("Failed to parse AI21 response: {}", e)))?;
62
63 self.convert_response(ai21_response)
64 }
65
66 pub async fn chat_completion_stream(
67 &self,
68 request: ChatCompletionRequest,
69 ) -> Result<Box<dyn futures::Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
70 let mut ai21_request = self.convert_request(&request)?;
71 ai21_request.stream = Some(true);
72
73 let response = self
74 .client
75 .post("https://api.ai21.com/studio/v1/chat/completions")
76 .header("Authorization", format!("Bearer {}", self.api_key))
77 .header("Content-Type", "application/json")
78 .json(&ai21_request)
79 .send()
80 .await
81 .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
82
83 if !response.status().is_success() {
84 let status = response.status();
85 let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
86 return Err(AiLibError::ProviderError(format!(
87 "AI21 API error {}: {}",
88 status, error_text
89 )));
90 }
91
92 let response = self.chat_completion(request.clone()).await?;
94
95 let chunk = crate::api::ChatCompletionChunk {
97 id: response.id.clone(),
98 object: "chat.completion.chunk".to_string(),
99 created: response.created,
100 model: response.model.clone(),
101 choices: response.choices.into_iter().map(|choice| {
102 crate::api::ChoiceDelta {
103 index: choice.index,
104 delta: crate::api::MessageDelta {
105 role: Some(choice.message.role),
106 content: Some(match &choice.message.content {
107 crate::Content::Text(text) => text.clone(),
108 _ => "".to_string(),
109 }),
110 },
111 finish_reason: choice.finish_reason,
112 }
113 }).collect(),
114 };
115
116 let stream = futures::stream::once(async move { Ok(chunk) });
117 Ok(Box::new(Box::pin(stream)))
118 }
119
120 fn convert_request(&self, request: &ChatCompletionRequest) -> Result<AI21Request, AiLibError> {
121 let messages = request
123 .messages
124 .iter()
125 .map(|msg| AI21Message {
126 role: match msg.role {
127 Role::System => "system".to_string(),
128 Role::User => "user".to_string(),
129 Role::Assistant => "assistant".to_string(),
130 },
131 content: match &msg.content {
132 crate::Content::Text(text) => text.clone(),
133 _ => "Unsupported content type".to_string(),
134 },
135 })
136 .collect();
137
138 Ok(AI21Request {
139 model: request.model.clone(),
140 messages,
141 max_tokens: request.max_tokens,
142 temperature: request.temperature,
143 top_p: request.top_p,
144 stream: Some(false),
145 })
146 }
147
148 fn convert_response(&self, response: AI21Response) -> Result<ChatCompletionResponse, AiLibError> {
149 let choice = response.choices.first()
150 .ok_or_else(|| AiLibError::InvalidModelResponse("No choices in AI21 response".to_string()))?;
151
152 let message = Message {
153 role: match choice.message.role.as_str() {
154 "assistant" => Role::Assistant,
155 "user" => Role::User,
156 "system" => Role::System,
157 _ => Role::Assistant,
158 },
159 content: crate::Content::Text(choice.message.content.clone().unwrap_or_default()),
160 function_call: None,
161 };
162
163 Ok(ChatCompletionResponse {
164 id: response.id,
165 object: "chat.completion".to_string(),
166 created: response.created,
167 model: response.model,
168 choices: vec![crate::types::Choice {
169 index: 0,
170 message,
171 finish_reason: choice.finish_reason.clone(),
172 }],
173 usage: response.usage.map(|u| crate::types::Usage {
174 prompt_tokens: u.prompt_tokens,
175 completion_tokens: u.completion_tokens,
176 total_tokens: u.total_tokens,
177 }).unwrap_or_else(|| crate::types::Usage {
178 prompt_tokens: 0,
179 completion_tokens: 0,
180 total_tokens: 0,
181 }),
182 usage_status: crate::types::response::UsageStatus::Finalized,
183 })
184 }
185}
186
187#[async_trait]
188impl ChatApi for AI21Adapter {
189 async fn chat_completion(
190 &self,
191 request: ChatCompletionRequest,
192 ) -> Result<ChatCompletionResponse, AiLibError> {
193 self.chat_completion(request).await
194 }
195
196 async fn chat_completion_stream(
197 &self,
198 request: ChatCompletionRequest,
199 ) -> Result<
200 Box<dyn Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>,
201 AiLibError,
202 > {
203 self.chat_completion_stream(request).await
204 }
205
206 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
207 Ok(vec![
209 "j2-ultra".to_string(),
210 "j2-mid".to_string(),
211 "j2-light".to_string(),
212 ])
213 }
214
215 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
216 Ok(crate::api::ModelInfo {
217 id: model_id.to_string(),
218 object: "model".to_string(),
219 created: 0,
220 owned_by: "ai21".to_string(),
221 permission: vec![],
222 })
223 }
224}
225
226#[derive(Serialize)]
227struct AI21Request {
228 model: String,
229 messages: Vec<AI21Message>,
230 max_tokens: Option<u32>,
231 temperature: Option<f32>,
232 top_p: Option<f32>,
233 stream: Option<bool>,
234}
235
236#[derive(Serialize)]
237struct AI21Message {
238 role: String,
239 content: String,
240}
241
242#[derive(Deserialize)]
243struct AI21Response {
244 id: String,
245 #[allow(dead_code)]
246 object: String,
247 created: u64,
248 model: String,
249 choices: Vec<AI21Choice>,
250 usage: Option<AI21Usage>,
251}
252
253#[derive(Deserialize)]
254struct AI21Choice {
255 message: AI21MessageResponse,
256 finish_reason: Option<String>,
257}
258
259#[derive(Deserialize)]
260struct AI21MessageResponse {
261 role: String,
262 content: Option<String>,
263}
264
265#[derive(Deserialize)]
266struct AI21Usage {
267 prompt_tokens: u32,
268 completion_tokens: u32,
269 total_tokens: u32,
270}