// ai_lib/provider/ai21.rs

1use crate::api::ChatProvider;
2use crate::types::AiLibError;
3use crate::types::{ChatCompletionRequest, ChatCompletionResponse, Message, Role};
4use async_trait::async_trait;
5use futures::Stream;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8
9/// AI21 API adapter
10///
11/// AI21 provides Jurassic series models with a custom API format.
12/// Documentation: <https://docs.ai21.com/reference/introduction>
13pub struct AI21Adapter {
14    client: Client,
15    api_key: String,
16}
17
18impl AI21Adapter {
19    pub fn new() -> Result<Self, AiLibError> {
20        let api_key = std::env::var("AI21_API_KEY").map_err(|_| {
21            AiLibError::ConfigurationError("AI21_API_KEY environment variable not set".to_string())
22        })?;
23
24        Ok(Self {
25            client: Client::new(),
26            api_key,
27        })
28    }
29
30    pub async fn chat_completion(
31        &self,
32        request: ChatCompletionRequest,
33    ) -> Result<ChatCompletionResponse, AiLibError> {
34        let ai21_request = self.convert_request(&request)?;
35
36        let response = self
37            .client
38            .post("https://api.ai21.com/studio/v1/chat/completions")
39            .header("Authorization", format!("Bearer {}", self.api_key))
40            .header("Content-Type", "application/json")
41            .json(&ai21_request)
42            .send()
43            .await
44            .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
45
46        if !response.status().is_success() {
47            let status = response.status();
48            let error_text = response
49                .text()
50                .await
51                .unwrap_or_else(|_| "Unknown error".to_string());
52            return Err(AiLibError::ProviderError(format!(
53                "AI21 API error {}: {}",
54                status, error_text
55            )));
56        }
57
58        let ai21_response: AI21Response = response.json().await.map_err(|e| {
59            AiLibError::DeserializationError(format!("Failed to parse AI21 response: {}", e))
60        })?;
61
62        self.convert_response(ai21_response)
63    }
64
65    pub async fn chat_completion_stream(
66        &self,
67        request: ChatCompletionRequest,
68    ) -> Result<
69        Box<
70            dyn futures::Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>>
71                + Send
72                + Unpin,
73        >,
74        AiLibError,
75    > {
76        let mut ai21_request = self.convert_request(&request)?;
77        ai21_request.stream = Some(true);
78
79        let response = self
80            .client
81            .post("https://api.ai21.com/studio/v1/chat/completions")
82            .header("Authorization", format!("Bearer {}", self.api_key))
83            .header("Content-Type", "application/json")
84            .json(&ai21_request)
85            .send()
86            .await
87            .map_err(|e| AiLibError::NetworkError(format!("AI21 API request failed: {}", e)))?;
88
89        if !response.status().is_success() {
90            let status = response.status();
91            let error_text = response
92                .text()
93                .await
94                .unwrap_or_else(|_| "Unknown error".to_string());
95            return Err(AiLibError::ProviderError(format!(
96                "AI21 API error {}: {}",
97                status, error_text
98            )));
99        }
100
101        // For now, convert streaming request to non-streaming and return a single chunk
102        let response = self.chat_completion(request.clone()).await?;
103
104        // Create a single chunk from the response
105        let chunk = crate::api::ChatCompletionChunk {
106            id: response.id.clone(),
107            object: "chat.completion.chunk".to_string(),
108            created: response.created,
109            model: response.model.clone(),
110            choices: response
111                .choices
112                .into_iter()
113                .map(|choice| crate::api::ChoiceDelta {
114                    index: choice.index,
115                    delta: crate::api::MessageDelta {
116                        role: Some(choice.message.role),
117                        content: Some(match &choice.message.content {
118                            crate::Content::Text(text) => text.clone(),
119                            _ => "".to_string(),
120                        }),
121                    },
122                    finish_reason: choice.finish_reason,
123                })
124                .collect(),
125        };
126
127        let stream = futures::stream::once(async move { Ok(chunk) });
128        Ok(Box::new(Box::pin(stream)))
129    }
130
131    fn convert_request(&self, request: &ChatCompletionRequest) -> Result<AI21Request, AiLibError> {
132        // Convert messages to AI21 format
133        let messages = request
134            .messages
135            .iter()
136            .map(|msg| AI21Message {
137                role: match msg.role {
138                    Role::System => "system".to_string(),
139                    Role::User => "user".to_string(),
140                    Role::Assistant => "assistant".to_string(),
141                },
142                content: match &msg.content {
143                    crate::Content::Text(text) => text.clone(),
144                    _ => "Unsupported content type".to_string(),
145                },
146            })
147            .collect();
148
149        Ok(AI21Request {
150            model: request.model.clone(),
151            messages,
152            max_tokens: request.max_tokens,
153            temperature: request.temperature,
154            top_p: request.top_p,
155            stream: Some(false),
156            extensions: request.extensions.clone(),
157        })
158    }
159
160    fn convert_response(
161        &self,
162        response: AI21Response,
163    ) -> Result<ChatCompletionResponse, AiLibError> {
164        let choice = response.choices.first().ok_or_else(|| {
165            AiLibError::InvalidModelResponse("No choices in AI21 response".to_string())
166        })?;
167
168        let message = Message {
169            role: match choice.message.role.as_str() {
170                "assistant" => Role::Assistant,
171                "user" => Role::User,
172                "system" => Role::System,
173                _ => Role::Assistant,
174            },
175            content: crate::Content::Text(choice.message.content.clone().unwrap_or_default()),
176            function_call: None,
177        };
178
179        Ok(ChatCompletionResponse {
180            id: response.id,
181            object: "chat.completion".to_string(),
182            created: response.created,
183            model: response.model,
184            choices: vec![crate::types::Choice {
185                index: 0,
186                message,
187                finish_reason: choice.finish_reason.clone(),
188            }],
189            usage: response
190                .usage
191                .map(|u| crate::types::Usage {
192                    prompt_tokens: u.prompt_tokens,
193                    completion_tokens: u.completion_tokens,
194                    total_tokens: u.total_tokens,
195                })
196                .unwrap_or_else(|| crate::types::Usage {
197                    prompt_tokens: 0,
198                    completion_tokens: 0,
199                    total_tokens: 0,
200                }),
201            usage_status: crate::types::response::UsageStatus::Finalized,
202        })
203    }
204}
205
206#[async_trait]
207impl ChatProvider for AI21Adapter {
208    fn name(&self) -> &str {
209        "AI21"
210    }
211
212    async fn chat(
213        &self,
214        request: ChatCompletionRequest,
215    ) -> Result<ChatCompletionResponse, AiLibError> {
216        self.chat_completion(request).await
217    }
218
219    async fn stream(
220        &self,
221        request: ChatCompletionRequest,
222    ) -> Result<
223        Box<dyn Stream<Item = Result<crate::api::ChatCompletionChunk, AiLibError>> + Send + Unpin>,
224        AiLibError,
225    > {
226        self.chat_completion_stream(request).await
227    }
228
229    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
230        // Return default models for AI21
231        Ok(vec![
232            "j2-ultra".to_string(),
233            "j2-mid".to_string(),
234            "j2-light".to_string(),
235        ])
236    }
237
238    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
239        Ok(crate::api::ModelInfo {
240            id: model_id.to_string(),
241            object: "model".to_string(),
242            created: 0,
243            owned_by: "ai21".to_string(),
244            permission: vec![],
245        })
246    }
247}
248
249#[derive(Serialize)]
250struct AI21Request {
251    model: String,
252    messages: Vec<AI21Message>,
253    max_tokens: Option<u32>,
254    temperature: Option<f32>,
255    top_p: Option<f32>,
256    stream: Option<bool>,
257    #[serde(flatten)]
258    extensions: Option<serde_json::Map<String, serde_json::Value>>,
259}
260
261#[derive(Serialize)]
262struct AI21Message {
263    role: String,
264    content: String,
265}
266
267#[derive(Deserialize)]
268struct AI21Response {
269    id: String,
270    #[allow(dead_code)]
271    object: String,
272    created: u64,
273    model: String,
274    choices: Vec<AI21Choice>,
275    usage: Option<AI21Usage>,
276}
277
278#[derive(Deserialize)]
279struct AI21Choice {
280    message: AI21MessageResponse,
281    finish_reason: Option<String>,
282}
283
284#[derive(Deserialize)]
285struct AI21MessageResponse {
286    role: String,
287    content: Option<String>,
288}
289
290#[derive(Deserialize)]
291struct AI21Usage {
292    prompt_tokens: u32,
293    completion_tokens: u32,
294    total_tokens: u32,
295}