1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Provider extension type for Mira.
///
/// A zero-sized marker used as the provider parameter of the generic
/// `client::Client` (see the `Client` alias in this module).
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraExt;
/// Builder marker whose `ProviderBuilder` impl produces a [`MiraExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraBuilder;
33
/// API key type for Mira: a [`BearerAuth`] credential.
type MiraApiKey = BearerAuth;
35
/// Registers Mira as a provider, paired with [`MiraBuilder`].
impl Provider for MiraExt {
    type Builder = MiraBuilder;

    // Path used for credential verification. Presumably the generic client's
    // verify step requests it — TODO confirm against `Provider::VERIFY_PATH` docs.
    const VERIFY_PATH: &'static str = "/user-credits";
}
41
/// Declares which client capabilities Mira supports: chat completion only.
impl<H> Capabilities<H> for MiraExt {
    type Completion = Capable<CompletionModel<H>>;
    // Every other capability is marked unsupported with `Nothing`.
    type Embeddings = Nothing;
    type Transcription = Nothing;
    // NOTE(review): model listing is reported as unsupported here even though
    // `Client::list_models` exists as an inherent method in this module.
    type ModelListing = Nothing;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
54
// No Mira-specific overrides: the trait's provided items are used as-is.
impl DebugExt for MiraExt {}
56
57impl ProviderBuilder for MiraBuilder {
58 type Extension<H>
59 = MiraExt
60 where
61 H: HttpClientExt;
62 type ApiKey = MiraApiKey;
63
64 const BASE_URL: &'static str = MIRA_API_BASE_URL;
65
66 fn build<H>(
67 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
68 ) -> http_client::Result<Self::Extension<H>>
69 where
70 H: HttpClientExt,
71 {
72 Ok(MiraExt)
73 }
74}
75
/// Mira API client. The HTTP backend parameter defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
/// Builder for [`Client`]. The HTTP backend parameter defaults to the
/// `crate::markers::Missing` marker.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
79
/// Errors returned by Mira-specific client helpers (e.g. `Client::list_models`).
#[derive(Debug, Error)]
pub enum MiraError {
    /// Rejected API key. NOTE(review): not constructed by any code visible in this module.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// Non-success HTTP response; carries the status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// Failure from the HTTP layer (building, sending, or reading a request/response).
    #[error("Request error: {0}")]
    RequestError(#[from] http_client::Error),
    /// Bytes that were expected to be UTF-8 text were not.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// A response body could not be parsed as the expected JSON.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
93
/// Error body shape `{ "message": ... }`.
///
/// Converted into `CompletionError::ProviderError` by a `From` impl in this module.
/// NOTE(review): no code visible here deserializes into this type.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
98
/// Wire-format chat message: a role string plus plain-text content.
///
/// Used both in request bodies and in response choices. Roles handled when
/// converting into rig's `Message` are `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    pub role: String,
    pub content: String,
}
104
/// Base URL of the Mira API (exposed as `MiraBuilder::BASE_URL`).
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
106
107impl TryFrom<RawMessage> for message::Message {
108 type Error = CompletionError;
109
110 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
111 match raw.role.as_str() {
112 "system" => Ok(message::Message::System {
113 content: raw.content,
114 }),
115 "user" => Ok(message::Message::User {
116 content: OneOrMany::one(UserContent::Text(message::Text::new(raw.content))),
117 }),
118 "assistant" => Ok(message::Message::Assistant {
119 id: None,
120 content: OneOrMany::one(AssistantContent::Text(message::Text::new(raw.content))),
121 }),
122 _ => Err(CompletionError::ResponseError(format!(
123 "Unsupported message role: {}",
124 raw.role
125 ))),
126 }
127 }
128}
129
/// Body returned by `POST /v1/chat/completions`.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order:
/// - an object with the chat-completion fields becomes `Structured`;
/// - a bare JSON string falls back to `Simple`.
///
/// `Structured` must stay first.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        // Omitted from serialized output (e.g. trace logs) when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    Simple(String),
}
144
/// One entry of `CompletionResponse::Structured::choices`.
///
/// Only `message` is required; `finish_reason` and `index` default to `None`
/// when missing from the JSON.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    pub message: RawMessage,
    #[serde(default)]
    pub finish_reason: Option<String>,
    #[serde(default)]
    pub index: Option<usize>,
}
153
/// Body of `GET /v1/models`, as parsed by `Client::list_models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A single model entry; only its `id` is read.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    id: String,
}
163
164impl<T> Client<T>
165where
166 T: HttpClientExt + 'static,
167{
168 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
170 let req = self.get("/v1/models").and_then(|req| {
171 req.body(http_client::NoBody)
172 .map_err(http_client::Error::Protocol)
173 })?;
174
175 let response = self.send(req).await?;
176
177 let status = response.status();
178
179 if !status.is_success() {
180 let error_text = http_client::text(response).await.unwrap_or_default();
182 tracing::error!("Error response: {}", error_text);
183 return Err(MiraError::ApiError(status.as_u16()));
184 }
185
186 let response_text = http_client::text(response).await?;
187
188 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
189 tracing::error!("Failed to parse response: {}", e);
190 MiraError::JsonError(e)
191 })?;
192
193 Ok(models.data.into_iter().map(|model| model.id).collect())
194 }
195}
196
197impl ProviderClient for Client {
198 type Input = String;
199 type Error = crate::client::ProviderClientError;
200
201 fn from_env() -> Result<Self, Self::Error> {
203 let api_key = crate::client::required_env_var("MIRA_API_KEY")?;
204 Self::new(&api_key).map_err(Into::into)
205 }
206
207 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
208 Self::new(&input).map_err(Into::into)
209 }
210}
211
/// JSON body sent to `POST /v1/chat/completions`.
///
/// `temperature` and `max_tokens` are omitted from the JSON when `None`.
/// `stream` is set by the caller (`false` for `completion`, `true` for `stream`).
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MiraCompletionRequest {
    model: String,
    pub messages: Vec<RawMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    pub stream: bool,
}
222
223impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
224 type Error = CompletionError;
225
226 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
227 if req.output_schema.is_some() {
228 tracing::warn!("Structured outputs currently not supported for Mira");
229 }
230 let model = req.model.clone().unwrap_or_else(|| model.to_string());
231 let mut messages = Vec::new();
232
233 if let Some(content) = &req.preamble {
234 messages.push(RawMessage {
235 role: "user".to_string(),
236 content: content.to_string(),
237 });
238 }
239
240 if let Some(Message::User { content }) = req.normalized_documents() {
241 let text = content
242 .into_iter()
243 .filter_map(|doc| match doc {
244 UserContent::Document(Document {
245 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
246 ..
247 }) => Some(data),
248 UserContent::Text(text) => Some(text.text),
249
250 _ => None,
252 })
253 .collect::<Vec<_>>()
254 .join("\n");
255
256 messages.push(RawMessage {
257 role: "user".to_string(),
258 content: text,
259 });
260 }
261
262 for msg in req.chat_history {
263 let (role, content) = match msg {
264 Message::System { content } => ("system", content),
265 Message::User { content } => {
266 let text = content
267 .iter()
268 .map(|c| match c {
269 UserContent::Text(text) => &text.text,
270 _ => "",
271 })
272 .collect::<Vec<_>>()
273 .join("\n");
274 ("user", text)
275 }
276 Message::Assistant { content, .. } => {
277 let text = content
278 .iter()
279 .map(|c| match c {
280 AssistantContent::Text(text) => &text.text,
281 _ => "",
282 })
283 .collect::<Vec<_>>()
284 .join("\n");
285 ("assistant", text)
286 }
287 };
288 messages.push(RawMessage {
289 role: role.to_string(),
290 content,
291 });
292 }
293
294 Ok(Self {
295 model: model.to_string(),
296 messages,
297 temperature: req.temperature,
298 max_tokens: req.max_tokens,
299 stream: false,
300 })
301 }
302}
303
304#[derive(Clone)]
305pub struct CompletionModel<T = reqwest::Client> {
306 client: Client<T>,
307 pub model: String,
309}
310
311impl<T> CompletionModel<T> {
312 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
313 Self {
314 client,
315 model: model.into(),
316 }
317 }
318}
319
320impl<T> completion::CompletionModel for CompletionModel<T>
321where
322 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
323{
324 type Response = CompletionResponse;
325 type StreamingResponse = openai::StreamingCompletionResponse;
326
327 type Client = Client<T>;
328
329 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
330 Self::new(client.clone(), model)
331 }
332
333 async fn completion(
334 &self,
335 completion_request: CompletionRequest,
336 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
337 let span = if tracing::Span::current().is_disabled() {
338 info_span!(
339 target: "rig::completions",
340 "chat",
341 gen_ai.operation.name = "chat",
342 gen_ai.provider.name = "mira",
343 gen_ai.request.model = self.model,
344 gen_ai.system_instructions = tracing::field::Empty,
345 gen_ai.response.id = tracing::field::Empty,
346 gen_ai.response.model = tracing::field::Empty,
347 gen_ai.usage.output_tokens = tracing::field::Empty,
348 gen_ai.usage.input_tokens = tracing::field::Empty,
349 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
350 )
351 } else {
352 tracing::Span::current()
353 };
354
355 span.record("gen_ai.system_instructions", &completion_request.preamble);
356
357 if !completion_request.tools.is_empty() {
358 tracing::warn!(target: "rig::completions",
359 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
360 len = completion_request.tools.len()
361 );
362 }
363
364 if completion_request.tool_choice.is_some() {
365 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
366 }
367
368 if completion_request.additional_params.is_some() {
369 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
370 }
371
372 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
373
374 if tracing::enabled!(tracing::Level::TRACE) {
375 tracing::trace!(target: "rig::completions",
376 "Mira completion request: {}",
377 serde_json::to_string_pretty(&request)?
378 );
379 }
380
381 let body = serde_json::to_vec(&request)?;
382
383 let req = self
384 .client
385 .post("/v1/chat/completions")?
386 .body(body)
387 .map_err(http_client::Error::from)?;
388
389 let async_block = async move {
390 let response = self
391 .client
392 .send::<_, bytes::Bytes>(req)
393 .await
394 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
395
396 let status = response.status();
397 let response_body = response.into_body().into_future().await?.to_vec();
398
399 if !status.is_success() {
400 let status = status.as_u16();
401 let error_text = String::from_utf8_lossy(&response_body).to_string();
402 return Err(CompletionError::ProviderError(format!(
403 "API error: {status} - {error_text}"
404 )));
405 }
406
407 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
408
409 if tracing::enabled!(tracing::Level::TRACE) {
410 tracing::trace!(target: "rig::completions",
411 "Mira completion response: {}",
412 serde_json::to_string_pretty(&response)?
413 );
414 }
415
416 if let CompletionResponse::Structured {
417 id, model, usage, ..
418 } = &response
419 {
420 let span = tracing::Span::current();
421 span.record("gen_ai.response.model", model);
422 span.record("gen_ai.response.id", id);
423 if let Some(usage) = usage {
424 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
425 span.record(
426 "gen_ai.usage.output_tokens",
427 usage.total_tokens - usage.prompt_tokens,
428 );
429 }
430 }
431
432 response.try_into()
433 };
434
435 async_block.instrument(span).await
436 }
437
438 async fn stream(
439 &self,
440 completion_request: CompletionRequest,
441 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
442 let span = if tracing::Span::current().is_disabled() {
443 info_span!(
444 target: "rig::completions",
445 "chat_streaming",
446 gen_ai.operation.name = "chat_streaming",
447 gen_ai.provider.name = "mira",
448 gen_ai.request.model = self.model,
449 gen_ai.system_instructions = tracing::field::Empty,
450 gen_ai.response.id = tracing::field::Empty,
451 gen_ai.response.model = tracing::field::Empty,
452 gen_ai.usage.output_tokens = tracing::field::Empty,
453 gen_ai.usage.input_tokens = tracing::field::Empty,
454 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
455 )
456 } else {
457 tracing::Span::current()
458 };
459
460 span.record("gen_ai.system_instructions", &completion_request.preamble);
461
462 if !completion_request.tools.is_empty() {
463 tracing::warn!(target: "rig::completions",
464 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
465 len = completion_request.tools.len()
466 );
467 }
468
469 if completion_request.tool_choice.is_some() {
470 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
471 }
472
473 if completion_request.additional_params.is_some() {
474 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
475 }
476 let mut request =
477 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
478 request.stream = true;
479
480 if tracing::enabled!(tracing::Level::TRACE) {
481 tracing::trace!(target: "rig::completions",
482 "Mira completion request: {}",
483 serde_json::to_string_pretty(&request)?
484 );
485 }
486
487 let body = serde_json::to_vec(&request)?;
488
489 let req = self
490 .client
491 .post("/v1/chat/completions")?
492 .body(body)
493 .map_err(http_client::Error::from)?;
494
495 send_compatible_streaming_request(self.client.clone(), req)
496 .instrument(span)
497 .await
498 }
499}
500
501impl From<ApiErrorResponse> for CompletionError {
502 fn from(err: ApiErrorResponse) -> Self {
503 CompletionError::ProviderError(err.message)
504 }
505}
506
507impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
508 type Error = CompletionError;
509
510 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
511 let (content, usage) = match &response {
512 CompletionResponse::Structured { choices, usage, .. } => {
513 let choice = choices.first().ok_or_else(|| {
514 CompletionError::ResponseError("Response contained no choices".to_owned())
515 })?;
516
517 let usage = usage
518 .as_ref()
519 .map(|usage| completion::Usage {
520 input_tokens: usage.prompt_tokens as u64,
521 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
522 total_tokens: usage.total_tokens as u64,
523 cached_input_tokens: 0,
524 cache_creation_input_tokens: 0,
525 tool_use_prompt_tokens: 0,
526 reasoning_tokens: 0,
527 })
528 .unwrap_or_default();
529
530 let message = message::Message::try_from(choice.message.clone())?;
532
533 let content = match message {
534 Message::Assistant { content, .. } => {
535 if content.is_empty() {
536 return Err(CompletionError::ResponseError(
537 "Response contained empty content".to_owned(),
538 ));
539 }
540
541 for c in content.iter() {
543 if !matches!(c, AssistantContent::Text(_)) {
544 tracing::warn!(target: "rig",
545 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
546 );
547 }
548 }
549
550 content.iter().map(|c| {
551 match c {
552 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
553 other => Err(CompletionError::ResponseError(
554 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
555 ))
556 }
557 }).collect::<Result<Vec<_>, _>>()?
558 }
559 Message::User { .. } => {
560 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
561 return Err(CompletionError::ResponseError(
562 "Received user message in response where assistant message was expected".to_owned()
563 ));
564 }
565 Message::System { .. } => {
566 tracing::warn!(target: "rig", "Received system message in response where assistant message was expected");
567 return Err(CompletionError::ResponseError(
568 "Received system message in response where assistant message was expected".to_owned(),
569 ));
570 }
571 };
572
573 (content, usage)
574 }
575 CompletionResponse::Simple(text) => (
576 vec![completion::AssistantContent::text(text)],
577 completion::Usage::new(),
578 ),
579 };
580
581 let choice = OneOrMany::many(content).map_err(|_| {
582 CompletionError::ResponseError(
583 "Response contained no message or tool call (empty)".to_owned(),
584 )
585 })?;
586
587 Ok(completion::CompletionResponse {
588 choice,
589 usage,
590 raw_response: response,
591 message_id: None,
592 })
593 }
594}
595
596#[derive(Clone, Debug, Deserialize, Serialize)]
597pub struct Usage {
598 pub prompt_tokens: usize,
599 pub total_tokens: usize,
600}
601
602impl std::fmt::Display for Usage {
603 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
604 write!(
605 f,
606 "Prompt tokens: {} Total tokens: {}",
607 self.prompt_tokens, self.total_tokens
608 )
609 }
610}
611
612impl From<Message> for serde_json::Value {
613 fn from(msg: Message) -> Self {
614 match msg {
615 Message::System { content } => serde_json::json!({
616 "role": "system",
617 "content": content
618 }),
619 Message::User { content } => {
620 let text = content
621 .iter()
622 .map(|c| match c {
623 UserContent::Text(text) => &text.text,
624 _ => "",
625 })
626 .collect::<Vec<_>>()
627 .join("\n");
628 serde_json::json!({
629 "role": "user",
630 "content": text
631 })
632 }
633 Message::Assistant { content, .. } => {
634 let text = content
635 .iter()
636 .map(|c| match c {
637 AssistantContent::Text(text) => &text.text,
638 _ => "",
639 })
640 .collect::<Vec<_>>()
641 .join("\n");
642 serde_json::json!({
643 "role": "assistant",
644 "content": text
645 })
646 }
647 }
648 }
649}
650
651impl TryFrom<serde_json::Value> for Message {
652 type Error = CompletionError;
653
654 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
655 let role = value
656 .get("role")
657 .and_then(serde_json::Value::as_str)
658 .ok_or_else(|| {
659 CompletionError::ResponseError("Message missing role field".to_owned())
660 })?;
661
662 let content = match value.get("content") {
664 Some(content) => match content {
665 serde_json::Value::String(s) => s.clone(),
666 serde_json::Value::Array(arr) => arr
667 .iter()
668 .filter_map(|c| {
669 c.get("text")
670 .and_then(|t| t.as_str())
671 .map(|text| text.to_string())
672 })
673 .collect::<Vec<_>>()
674 .join("\n"),
675 _ => {
676 return Err(CompletionError::ResponseError(
677 "Message content must be string or array".to_owned(),
678 ));
679 }
680 },
681 None => {
682 return Err(CompletionError::ResponseError(
683 "Message missing content field".to_owned(),
684 ));
685 }
686 };
687
688 match role {
689 "system" => Ok(Message::System { content }),
690 "user" => Ok(Message::User {
691 content: OneOrMany::one(UserContent::Text(message::Text::new(content))),
692 }),
693 "assistant" => Ok(Message::Assistant {
694 id: None,
695 content: OneOrMany::one(AssistantContent::Text(message::Text::new(content))),
696 }),
697 _ => Err(CompletionError::ResponseError(format!(
698 "Unsupported message role: {role}"
699 ))),
700 }
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant greeting shared by the deserialization fixtures.
    const GREETING: &str = "Hello there, how may I assist you today?";

    /// String- and array-shaped `content` both parse into the expected variants.
    #[test]
    fn test_deserialize_message() {
        let parse = |value: serde_json::Value| Message::try_from(value).unwrap();

        let assistant_message = parse(json!({
            "role": "assistant",
            "content": GREETING
        }));
        let user_message = parse(json!({
            "role": "user",
            "content": "What can you help me with?"
        }));
        let assistant_message_array = parse(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }));

        let expected_greeting = AssistantContent::Text(message::Text::new(GREETING.to_string()));

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text::new("What can you help me with?".to_string()))
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);
    }

    /// A text-only user message survives a `Message -> Value -> Message` round trip.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let round_tripped: Message = serde_json::Value::from(original_message.clone())
            .try_into()
            .unwrap();

        assert_eq!(original_message, round_tripped);
    }

    /// A structured response yields its first choice's text as the completion.
    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let reported_usage = Usage {
            prompt_tokens: 10,
            total_tokens: 20,
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(reported_usage),
        };

        let converted: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// Both construction paths succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _built = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}