1use crate::json_utils::merge;
11use crate::providers::openai::send_compatible_streaming_request;
12use crate::streaming::{StreamingCompletionModel, StreamingResult};
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError, CompletionRequest},
16 extractor::ExtractorBuilder,
17 message::{self, AssistantContent, Message, UserContent},
18 OneOrMany,
19};
20use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
/// Errors returned by the Mira [`Client`] itself: construction and model listing.
///
/// Completion and streaming calls report failures through [`CompletionError`] instead.
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key is unusable, e.g. it cannot be encoded into an `Authorization` header value.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; holds the numeric status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// Transport-level failure reported by `reqwest`.
    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// Failure converting a byte buffer to a `String` (wraps `FromUtf8Error`).
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// A response body could not be parsed into the expected JSON structure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
41
/// Deserializable error body of the shape `{"message": "..."}`.
///
/// Converted into [`CompletionError::ProviderError`] through its `From` impl.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error text supplied by the API.
    message: String,
}
46
/// A chat message exactly as it appears on the wire: a role string and plain-text content.
///
/// Converted into a rig [`message::Message`] via `TryFrom`; only the `"user"` and
/// `"assistant"` roles are accepted by that conversion.
#[derive(Debug, Deserialize, Clone)]
pub struct RawMessage {
    /// Speaker role, e.g. `"assistant"`.
    pub role: String,
    /// Message text.
    pub content: String,
}
52
/// Default Mira API endpoint used by [`Client::new`]; override it with [`Client::new_with_base_url`].
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 content: OneOrMany::one(AssistantContent::Text(message::Text {
65 text: raw.content,
66 })),
67 }),
68 _ => Err(CompletionError::ResponseError(format!(
69 "Unsupported message role: {}",
70 raw.role
71 ))),
72 }
73 }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79 Structured {
80 id: String,
81 object: String,
82 created: u64,
83 model: String,
84 choices: Vec<ChatChoice>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 usage: Option<Usage>,
87 },
88 Simple(String),
89}
90
/// One element of the `choices` array in a [`CompletionResponse::Structured`] response.
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
    /// The returned message (role plus text content).
    pub message: RawMessage,
    /// Finish-reason string reported by the API (e.g. `"stop"`); `None` when absent.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Choice index reported by the API; `None` when absent.
    #[serde(default)]
    pub index: Option<usize>,
}
99
/// Body returned by `GET /v1/models`, as consumed by [`Client::list_models`].
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    /// Available models.
    data: Vec<ModelInfo>,
}

/// A single entry of [`ModelsResponse::data`]; only the model identifier is read.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    /// Model identifier, suitable for [`Client::completion_model`].
    id: String,
}
109
/// HTTP client for the Mira API.
///
/// Holds the API base URL, a reusable `reqwest` client, and the default headers
/// (including the `Authorization: Bearer` token) attached to every request.
#[derive(Clone)]
pub struct Client {
    /// Base URL that endpoint paths such as `/v1/models` are appended to.
    base_url: String,
    /// Underlying HTTP client.
    client: reqwest::Client,
    /// Headers sent with every request.
    headers: HeaderMap,
}
117
118impl Client {
119 pub fn new(api_key: &str) -> Result<Self, MiraError> {
121 let mut headers = HeaderMap::new();
122 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123 headers.insert(
124 AUTHORIZATION,
125 HeaderValue::from_str(&format!("Bearer {}", api_key))
126 .map_err(|_| MiraError::InvalidApiKey)?,
127 );
128 headers.insert(
129 reqwest::header::ACCEPT,
130 HeaderValue::from_static("application/json"),
131 );
132 headers.insert(
133 reqwest::header::USER_AGENT,
134 HeaderValue::from_static("rig-client/1.0"),
135 );
136
137 Ok(Self {
138 base_url: MIRA_API_BASE_URL.to_string(),
139 client: reqwest::Client::builder()
140 .build()
141 .expect("Failed to build HTTP client"),
142 headers,
143 })
144 }
145
146 pub fn from_env() -> Result<Self, MiraError> {
149 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
150 Self::new(&api_key)
151 }
152
153 pub fn new_with_base_url(
155 api_key: &str,
156 base_url: impl Into<String>,
157 ) -> Result<Self, MiraError> {
158 let mut client = Self::new(api_key)?;
159 client.base_url = base_url.into();
160 Ok(client)
161 }
162
163 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165 let url = format!("{}/v1/models", self.base_url);
166
167 let response = self
168 .client
169 .get(&url)
170 .headers(self.headers.clone())
171 .send()
172 .await?;
173
174 let status = response.status();
175
176 if !status.is_success() {
177 let _error_text = response.text().await.unwrap_or_default();
179 tracing::error!("Error response: {}", _error_text);
180 return Err(MiraError::ApiError(status.as_u16()));
181 }
182
183 let response_text = response.text().await?;
184
185 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
186 tracing::error!("Failed to parse response: {}", e);
187 MiraError::JsonError(e)
188 })?;
189
190 Ok(models.data.into_iter().map(|model| model.id).collect())
191 }
192
193 pub fn completion_model(&self, model: &str) -> CompletionModel {
195 CompletionModel::new(self.to_owned(), model)
196 }
197
198 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
200 AgentBuilder::new(self.completion_model(model))
201 }
202
203 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
205 &self,
206 model: &str,
207 ) -> ExtractorBuilder<T, CompletionModel> {
208 ExtractorBuilder::new(self.completion_model(model))
209 }
210}
211
/// A Mira chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests (base URL, HTTP client, default headers).
    client: Client,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
}
218
219impl CompletionModel {
220 pub fn new(client: Client, model: &str) -> Self {
221 Self {
222 client,
223 model: model.to_string(),
224 }
225 }
226
227 fn create_completion_request(
228 &self,
229 completion_request: CompletionRequest,
230 ) -> Result<Value, CompletionError> {
231 let mut messages = Vec::new();
232
233 if let Some(preamble) = &completion_request.preamble {
235 messages.push(serde_json::json!({
236 "role": "user",
237 "content": preamble.to_string()
238 }));
239 }
240
241 messages.push(match &completion_request.prompt {
243 Message::User { content } => {
244 let text = content
245 .iter()
246 .map(|c| match c {
247 UserContent::Text(text) => &text.text,
248 _ => "",
249 })
250 .collect::<Vec<_>>()
251 .join("\n");
252 serde_json::json!({
253 "role": "user",
254 "content": text
255 })
256 }
257 _ => unreachable!(),
258 });
259
260 for msg in completion_request.chat_history {
262 let (role, content) = match msg {
263 Message::User { content } => {
264 let text = content
265 .iter()
266 .map(|c| match c {
267 UserContent::Text(text) => &text.text,
268 _ => "",
269 })
270 .collect::<Vec<_>>()
271 .join("\n");
272 ("user", text)
273 }
274 Message::Assistant { content } => {
275 let text = content
276 .iter()
277 .map(|c| match c {
278 AssistantContent::Text(text) => &text.text,
279 _ => "",
280 })
281 .collect::<Vec<_>>()
282 .join("\n");
283 ("assistant", text)
284 }
285 };
286 messages.push(serde_json::json!({
287 "role": role,
288 "content": content
289 }));
290 }
291
292 let request = serde_json::json!({
293 "model": self.model,
294 "messages": messages,
295 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
296 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
297 "stream": false
298 });
299
300 Ok(request)
301 }
302}
303
304impl completion::CompletionModel for CompletionModel {
305 type Response = CompletionResponse;
306
307 #[cfg_attr(feature = "worker", worker::send)]
308 async fn completion(
309 &self,
310 completion_request: CompletionRequest,
311 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
312 if !completion_request.tools.is_empty() {
313 tracing::warn!(target: "rig",
314 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
315 completion_request.tools.len()
316 );
317 }
318
319 let mira_request = self.create_completion_request(completion_request)?;
320
321 let response = self
322 .client
323 .client
324 .post(format!("{}/v1/chat/completions", self.client.base_url))
325 .headers(self.client.headers.clone())
326 .json(&mira_request)
327 .send()
328 .await
329 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
330
331 if !response.status().is_success() {
332 let status = response.status().as_u16();
333 let error_text = response.text().await.unwrap_or_default();
334 return Err(CompletionError::ProviderError(format!(
335 "API error: {} - {}",
336 status, error_text
337 )));
338 }
339
340 let response: CompletionResponse = response
341 .json()
342 .await
343 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
344
345 response.try_into()
346 }
347}
348
349impl StreamingCompletionModel for CompletionModel {
350 async fn stream(
351 &self,
352 completion_request: CompletionRequest,
353 ) -> Result<StreamingResult, CompletionError> {
354 let mut request = self.create_completion_request(completion_request)?;
355
356 request = merge(request, json!({"stream": true}));
357
358 let builder = self
359 .client
360 .client
361 .post(format!("{}/v1/chat/completions", self.client.base_url))
362 .headers(self.client.headers.clone())
363 .json(&request);
364
365 send_compatible_streaming_request(builder).await
366 }
367}
368
369impl From<ApiErrorResponse> for CompletionError {
370 fn from(err: ApiErrorResponse) -> Self {
371 CompletionError::ProviderError(err.message)
372 }
373}
374
375impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
376 type Error = CompletionError;
377
378 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
379 let content = match &response {
380 CompletionResponse::Structured { choices, .. } => {
381 let choice = choices.first().ok_or_else(|| {
382 CompletionError::ResponseError("Response contained no choices".to_owned())
383 })?;
384
385 let message = message::Message::try_from(choice.message.clone())?;
387
388 match message {
389 Message::Assistant { content } => {
390 if content.is_empty() {
391 return Err(CompletionError::ResponseError(
392 "Response contained empty content".to_owned(),
393 ));
394 }
395
396 for c in content.iter() {
398 if !matches!(c, AssistantContent::Text(_)) {
399 tracing::warn!(target: "rig",
400 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
401 );
402 }
403 }
404
405 content.iter().map(|c| {
406 match c {
407 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
408 other => Err(CompletionError::ResponseError(
409 format!("Unsupported content type: {:?}. The Mira provider currently only supports text content", other)
410 ))
411 }
412 }).collect::<Result<Vec<_>, _>>()?
413 }
414 Message::User { .. } => {
415 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
416 return Err(CompletionError::ResponseError(
417 "Received user message in response where assistant message was expected".to_owned()
418 ));
419 }
420 }
421 }
422 CompletionResponse::Simple(text) => {
423 vec![completion::AssistantContent::text(text)]
424 }
425 };
426
427 let choice = OneOrMany::many(content).map_err(|_| {
428 CompletionError::ResponseError(
429 "Response contained no message or tool call (empty)".to_owned(),
430 )
431 })?;
432
433 Ok(completion::CompletionResponse {
434 choice,
435 raw_response: response,
436 })
437 }
438}
439
/// Token counts reported in the `usage` field of a structured completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: usize,
    /// Total tokens reported for the request.
    pub total_tokens: usize,
}
445
446impl std::fmt::Display for Usage {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 write!(
449 f,
450 "Prompt tokens: {} Total tokens: {}",
451 self.prompt_tokens, self.total_tokens
452 )
453 }
454}
455
456impl From<Message> for serde_json::Value {
457 fn from(msg: Message) -> Self {
458 match msg {
459 Message::User { content } => {
460 let text = content
461 .iter()
462 .map(|c| match c {
463 UserContent::Text(text) => &text.text,
464 _ => "",
465 })
466 .collect::<Vec<_>>()
467 .join("\n");
468 serde_json::json!({
469 "role": "user",
470 "content": text
471 })
472 }
473 Message::Assistant { content } => {
474 let text = content
475 .iter()
476 .map(|c| match c {
477 AssistantContent::Text(text) => &text.text,
478 _ => "",
479 })
480 .collect::<Vec<_>>()
481 .join("\n");
482 serde_json::json!({
483 "role": "assistant",
484 "content": text
485 })
486 }
487 }
488 }
489}
490
491impl TryFrom<serde_json::Value> for Message {
492 type Error = CompletionError;
493
494 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
495 let role = value["role"].as_str().ok_or_else(|| {
496 CompletionError::ResponseError("Message missing role field".to_owned())
497 })?;
498
499 let content = match value.get("content") {
501 Some(content) => match content {
502 serde_json::Value::String(s) => s.clone(),
503 serde_json::Value::Array(arr) => arr
504 .iter()
505 .filter_map(|c| {
506 c.get("text")
507 .and_then(|t| t.as_str())
508 .map(|text| text.to_string())
509 })
510 .collect::<Vec<_>>()
511 .join("\n"),
512 _ => {
513 return Err(CompletionError::ResponseError(
514 "Message content must be string or array".to_owned(),
515 ))
516 }
517 },
518 None => {
519 return Err(CompletionError::ResponseError(
520 "Message missing content field".to_owned(),
521 ))
522 }
523 };
524
525 match role {
526 "user" => Ok(Message::User {
527 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
528 }),
529 "assistant" => Ok(Message::Assistant {
530 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
531 }),
532 _ => Err(CompletionError::ResponseError(format!(
533 "Unsupported message role: {}",
534 role
535 ))),
536 }
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// String and array-of-text-parts JSON messages parse into the matching rig roles.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A user text message survives a round trip through the JSON wire format.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().try_into().unwrap();

        let converted_message: Message = mira_value.try_into().unwrap();

        let final_message: message::Message = converted_message.try_into().unwrap();

        assert_eq!(original_message, final_message);
    }

    /// The first structured choice becomes the rig response's text content.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// A plain-string response becomes a single text part.
    #[test]
    fn test_simple_response_conversion() {
        let completion_response: completion::CompletionResponse<CompletionResponse> =
            CompletionResponse::Simple("Plain text reply".to_string())
                .try_into()
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Plain text reply")
        );
    }

    /// A structured response without choices is rejected rather than panicking.
    #[test]
    fn test_empty_choices_is_error() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_empty".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "deepseek-r1".to_string(),
            choices: vec![],
            usage: None,
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
            mira_response.try_into();

        assert!(matches!(result, Err(CompletionError::ResponseError(_))));
    }

    /// Roles other than user/assistant are rejected by the raw-message conversion.
    #[test]
    fn test_unsupported_role_is_error() {
        let raw = RawMessage {
            role: "system".to_string(),
            content: "ignored".to_string(),
        };

        assert!(matches!(
            message::Message::try_from(raw),
            Err(CompletionError::ResponseError(_))
        ));
    }

    /// Untagged deserialization: objects become `Structured` (with `usage` optional),
    /// bare strings become `Simple`.
    #[test]
    fn test_deserialize_completion_response() {
        let structured: CompletionResponse = serde_json::from_value(json!({
            "id": "resp_1",
            "object": "chat.completion",
            "created": 1,
            "model": "deepseek-r1",
            "choices": [{
                "message": { "role": "assistant", "content": "hi" }
            }]
        }))
        .unwrap();
        assert!(matches!(
            structured,
            CompletionResponse::Structured { usage: None, .. }
        ));

        let simple: CompletionResponse = serde_json::from_value(json!("hi")).unwrap();
        assert!(matches!(simple, CompletionResponse::Simple(ref s) if s == "hi"));
    }
}