ai_lib/provider/
ai21.rs

1use crate::api::ChatApi;
2use crate::types::{
3    ChatCompletionRequest, ChatCompletionResponse, Message, Role,
4};
5use crate::types::AiLibError;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use futures::Stream;
9use async_trait::async_trait;
10
11/// AI21 API adapter
12/// 
13/// AI21 provides Jurassic series models with a custom API format.
14/// Documentation: https://docs.ai21.com/reference/introduction
15pub struct AI21Adapter {
16    client: Client,
17    api_key: String,
18}
19
20impl AI21Adapter {
21    pub fn new() -> Result<Self, AiLibError> {
22        let api_key = std::env::var("AI21_API_KEY")
23            .map_err(|_| AiLibError::ConfigurationError(
24                "AI21_API_KEY environment variable not set".to_string()
25            ))?;
26
27        Ok(Self {
28            client: Client::new(),
29            api_key,
30        })
31    }
32
33    pub async fn chat_completion(
34        &self,
35        request: ChatCompletionRequest,
36    ) -> Result<ChatCompletionResponse, AiLibError> {
37        let ai21_request = self.convert_request(&request)?;
38        
39        let response = self
40            .client
41            .post("https://api.ai21.com/studio/v1/chat/completions")
42            .header("Authorization", format!("Bearer {}", self.api_key))
43            .header("Content-Type", "application/json")
44            .json(&ai21_request)
45            .send()
46            .await
47            .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
48
49        if !response.status().is_success() {
50            let status = response.status();
51            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
52            return Err(AiLibError::ProviderError(format!(
53                "AI21 API error {}: {}",
54                status, error_text
55            )));
56        }
57
58        let ai21_response: AI21Response = response
59            .json()
60            .await
61            .map_err(|e| AiLibError::DeserializationError(format!("Failed to parse AI21 response: {}", e)))?;
62
63        self.convert_response(ai21_response)
64    }
65
66    pub async fn chat_completion_stream(
67        &self,
68        request: ChatCompletionRequest,
69    ) -> Result<Box<dyn futures::Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
70        let mut ai21_request = self.convert_request(&request)?;
71        ai21_request.stream = Some(true);
72
73        let response = self
74            .client
75            .post("https://api.ai21.com/studio/v1/chat/completions")
76            .header("Authorization", format!("Bearer {}", self.api_key))
77            .header("Content-Type", "application/json")
78            .json(&ai21_request)
79            .send()
80            .await
81            .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
82
83        if !response.status().is_success() {
84            let status = response.status();
85            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
86            return Err(AiLibError::ProviderError(format!(
87                "AI21 API error {}: {}",
88                status, error_text
89            )));
90        }
91
92        // For now, convert streaming request to non-streaming and return a single chunk
93        let response = self.chat_completion(request.clone()).await?;
94        
95        // Create a single chunk from the response
96        let chunk = crate::api::ChatCompletionChunk {
97            id: response.id.clone(),
98            object: "chat.completion.chunk".to_string(),
99            created: response.created,
100            model: response.model.clone(),
101            choices: response.choices.into_iter().map(|choice| {
102                crate::api::ChoiceDelta {
103                    index: choice.index,
104                    delta: crate::api::MessageDelta {
105                        role: Some(choice.message.role),
106                        content: Some(match &choice.message.content {
107                            crate::Content::Text(text) => text.clone(),
108                            _ => "".to_string(),
109                        }),
110                    },
111                    finish_reason: choice.finish_reason,
112                }
113            }).collect(),
114        };
115        
116        let stream = futures::stream::once(async move { Ok(chunk) });
117        Ok(Box::new(Box::pin(stream)))
118    }
119
120    fn convert_request(&self, request: &ChatCompletionRequest) -> Result<AI21Request, AiLibError> {
121        // Convert messages to AI21 format
122        let messages = request
123            .messages
124            .iter()
125            .map(|msg| AI21Message {
126                role: match msg.role {
127                    Role::System => "system".to_string(),
128                    Role::User => "user".to_string(),
129                    Role::Assistant => "assistant".to_string(),
130                },
131                content: match &msg.content {
132                    crate::Content::Text(text) => text.clone(),
133                    _ => "Unsupported content type".to_string(),
134                },
135            })
136            .collect();
137
138        Ok(AI21Request {
139            model: request.model.clone(),
140            messages,
141            max_tokens: request.max_tokens,
142            temperature: request.temperature,
143            top_p: request.top_p,
144            stream: Some(false),
145        })
146    }
147
148    fn convert_response(&self, response: AI21Response) -> Result<ChatCompletionResponse, AiLibError> {
149        let choice = response.choices.first()
150            .ok_or_else(|| AiLibError::InvalidModelResponse("No choices in AI21 response".to_string()))?;
151
152        let message = Message {
153            role: match choice.message.role.as_str() {
154                "assistant" => Role::Assistant,
155                "user" => Role::User,
156                "system" => Role::System,
157                _ => Role::Assistant,
158            },
159            content: crate::Content::Text(choice.message.content.clone().unwrap_or_default()),
160            function_call: None,
161        };
162
163        Ok(ChatCompletionResponse {
164            id: response.id,
165            object: "chat.completion".to_string(),
166            created: response.created,
167            model: response.model,
168            choices: vec![crate::types::Choice {
169                index: 0,
170                message,
171                finish_reason: choice.finish_reason.clone(),
172            }],
173            usage: response.usage.map(|u| crate::types::Usage {
174                prompt_tokens: u.prompt_tokens,
175                completion_tokens: u.completion_tokens,
176                total_tokens: u.total_tokens,
177            }).unwrap_or_else(|| crate::types::Usage {
178                prompt_tokens: 0,
179                completion_tokens: 0,
180                total_tokens: 0,
181            }),
182            usage_status: crate::types::response::UsageStatus::Finalized,
183        })
184    }
185}
186
187#[async_trait]
188impl ChatApi for AI21Adapter {
189    async fn chat_completion(
190        &self,
191        request: ChatCompletionRequest,
192    ) -> Result<ChatCompletionResponse, AiLibError> {
193        self.chat_completion(request).await
194    }
195
196    async fn chat_completion_stream(
197        &self,
198        request: ChatCompletionRequest,
199    ) -> Result<
200        Box<dyn Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>,
201        AiLibError,
202    > {
203        self.chat_completion_stream(request).await
204    }
205
206    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
207        // Return default models for AI21
208        Ok(vec![
209            "j2-ultra".to_string(),
210            "j2-mid".to_string(),
211            "j2-light".to_string(),
212        ])
213    }
214
215    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
216        Ok(crate::api::ModelInfo {
217            id: model_id.to_string(),
218            object: "model".to_string(),
219            created: 0,
220            owned_by: "ai21".to_string(),
221            permission: vec![],
222        })
223    }
224}
225
226#[derive(Serialize)]
227struct AI21Request {
228    model: String,
229    messages: Vec<AI21Message>,
230    max_tokens: Option<u32>,
231    temperature: Option<f32>,
232    top_p: Option<f32>,
233    stream: Option<bool>,
234}
235
236#[derive(Serialize)]
237struct AI21Message {
238    role: String,
239    content: String,
240}
241
242#[derive(Deserialize)]
243struct AI21Response {
244    id: String,
245    #[allow(dead_code)]
246    object: String,
247    created: u64,
248    model: String,
249    choices: Vec<AI21Choice>,
250    usage: Option<AI21Usage>,
251}
252
253#[derive(Deserialize)]
254struct AI21Choice {
255    message: AI21MessageResponse,
256    finish_reason: Option<String>,
257}
258
259#[derive(Deserialize)]
260struct AI21MessageResponse {
261    role: String,
262    content: Option<String>,
263}
264
265#[derive(Deserialize)]
266struct AI21Usage {
267    prompt_tokens: u32,
268    completion_tokens: u32,
269    total_tokens: u32,
270}