1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Provider marker type for the Mira AI API.
///
/// Carries no state; it only selects the Mira implementations of the client traits.
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraExt;

/// Builder marker type that produces [`MiraExt`] clients.
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37 type Builder = MiraBuilder;
38
39 const VERIFY_PATH: &'static str = "/user-credits";
40
41 fn build<H>(
42 _: &crate::client::ClientBuilder<
43 Self::Builder,
44 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
45 H,
46 >,
47 ) -> http_client::Result<Self> {
48 Ok(Self)
49 }
50}
51
52impl<H> Capabilities<H> for MiraExt {
53 type Completion = Capable<CompletionModel<H>>;
54 type Embeddings = Nothing;
55 type Transcription = Nothing;
56
57 #[cfg(feature = "image")]
58 type ImageGeneration = Nothing;
59
60 #[cfg(feature = "audio")]
61 type AudioGeneration = Nothing;
62}
63
64impl DebugExt for MiraExt {}
65
66impl ProviderBuilder for MiraBuilder {
67 type Output = MiraExt;
68 type ApiKey = MiraApiKey;
69
70 const BASE_URL: &'static str = MIRA_API_BASE_URL;
71}
72
73pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
74pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
75
76#[derive(Debug, Error)]
77pub enum MiraError {
78 #[error("Invalid API key")]
79 InvalidApiKey,
80 #[error("API error: {0}")]
81 ApiError(u16),
82 #[error("Request error: {0}")]
83 RequestError(#[from] http_client::Error),
84 #[error("UTF-8 error: {0}")]
85 Utf8Error(#[from] FromUtf8Error),
86 #[error("JSON error: {0}")]
87 JsonError(#[from] serde_json::Error),
88}
89
90#[derive(Debug, Deserialize)]
91struct ApiErrorResponse {
92 message: String,
93}
94
95#[derive(Debug, Deserialize, Clone, Serialize)]
96pub struct RawMessage {
97 pub role: String,
98 pub content: String,
99}
100
/// Base URL for all Mira API requests; request paths such as `/v1/chat/completions` are relative to it.
const MIRA_API_BASE_URL: &'static str = "https://api.mira.network";
102
103impl TryFrom<RawMessage> for message::Message {
104 type Error = CompletionError;
105
106 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
107 match raw.role.as_str() {
108 "user" => Ok(message::Message::User {
109 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
110 }),
111 "assistant" => Ok(message::Message::Assistant {
112 id: None,
113 content: OneOrMany::one(AssistantContent::Text(message::Text {
114 text: raw.content,
115 })),
116 }),
117 _ => Err(CompletionError::ResponseError(format!(
118 "Unsupported message role: {}",
119 raw.role
120 ))),
121 }
122 }
123}
124
/// Body of a Mira chat-completion response.
///
/// The enum is `#[serde(untagged)]`: deserialization tries `Structured` first and
/// falls back to `Simple`, so the variant order is significant.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    /// Object-shaped response with metadata, `choices`, and optional `usage`.
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        // Omitted from serialized output when the API did not report usage.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    /// A response body that is just a JSON string.
    Simple(String),
}
139
140#[derive(Debug, Deserialize, Serialize)]
141pub struct ChatChoice {
142 pub message: RawMessage,
143 #[serde(default)]
144 pub finish_reason: Option<String>,
145 #[serde(default)]
146 pub index: Option<usize>,
147}
148
149#[derive(Debug, Deserialize, Serialize)]
150struct ModelsResponse {
151 data: Vec<ModelInfo>,
152}
153
154#[derive(Debug, Deserialize, Serialize)]
155struct ModelInfo {
156 id: String,
157}
158
159impl<T> Client<T>
160where
161 T: HttpClientExt + 'static,
162{
163 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165 let req = self.get("/v1/models").and_then(|req| {
166 req.body(http_client::NoBody)
167 .map_err(http_client::Error::Protocol)
168 })?;
169
170 let response = self.send(req).await?;
171
172 let status = response.status();
173
174 if !status.is_success() {
175 let error_text = http_client::text(response).await.unwrap_or_default();
177 tracing::error!("Error response: {}", error_text);
178 return Err(MiraError::ApiError(status.as_u16()));
179 }
180
181 let response_text = http_client::text(response).await?;
182
183 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
184 tracing::error!("Failed to parse response: {}", e);
185 MiraError::JsonError(e)
186 })?;
187
188 Ok(models.data.into_iter().map(|model| model.id).collect())
189 }
190}
191
192impl ProviderClient for Client {
193 type Input = String;
194
195 fn from_env() -> Self {
198 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
199 Self::new(&api_key).unwrap()
200 }
201
202 fn from_val(input: Self::Input) -> Self {
203 Self::new(&input).unwrap()
204 }
205}
206
207#[derive(Debug, Serialize, Deserialize)]
208pub(super) struct MiraCompletionRequest {
209 model: String,
210 pub messages: Vec<RawMessage>,
211 #[serde(flatten, skip_serializing_if = "Option::is_none")]
212 temperature: Option<f64>,
213 #[serde(flatten, skip_serializing_if = "Option::is_none")]
214 max_tokens: Option<u64>,
215 pub stream: bool,
216}
217
218impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
219 type Error = CompletionError;
220
221 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
222 let mut messages = Vec::new();
223
224 if let Some(content) = &req.preamble {
225 messages.push(RawMessage {
226 role: "user".to_string(),
227 content: content.to_string(),
228 });
229 }
230
231 if let Some(Message::User { content }) = req.normalized_documents() {
232 let text = content
233 .into_iter()
234 .filter_map(|doc| match doc {
235 UserContent::Document(Document {
236 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
237 ..
238 }) => Some(data),
239 UserContent::Text(text) => Some(text.text),
240
241 _ => None,
243 })
244 .collect::<Vec<_>>()
245 .join("\n");
246
247 messages.push(RawMessage {
248 role: "user".to_string(),
249 content: text,
250 });
251 }
252
253 for msg in req.chat_history {
254 let (role, content) = match msg {
255 Message::User { content } => {
256 let text = content
257 .iter()
258 .map(|c| match c {
259 UserContent::Text(text) => &text.text,
260 _ => "",
261 })
262 .collect::<Vec<_>>()
263 .join("\n");
264 ("user", text)
265 }
266 Message::Assistant { content, .. } => {
267 let text = content
268 .iter()
269 .map(|c| match c {
270 AssistantContent::Text(text) => &text.text,
271 _ => "",
272 })
273 .collect::<Vec<_>>()
274 .join("\n");
275 ("assistant", text)
276 }
277 };
278 messages.push(RawMessage {
279 role: role.to_string(),
280 content,
281 });
282 }
283
284 Ok(Self {
285 model: model.to_string(),
286 messages,
287 temperature: req.temperature,
288 max_tokens: req.max_tokens,
289 stream: false,
290 })
291 }
292}
293
294#[derive(Clone)]
295pub struct CompletionModel<T = reqwest::Client> {
296 client: Client<T>,
297 pub model: String,
299}
300
301impl<T> CompletionModel<T> {
302 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
303 Self {
304 client,
305 model: model.into(),
306 }
307 }
308}
309
310impl<T> completion::CompletionModel for CompletionModel<T>
311where
312 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
313{
314 type Response = CompletionResponse;
315 type StreamingResponse = openai::StreamingCompletionResponse;
316
317 type Client = Client<T>;
318
319 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
320 Self::new(client.clone(), model)
321 }
322
323 #[cfg_attr(feature = "worker", worker::send)]
324 async fn completion(
325 &self,
326 completion_request: CompletionRequest,
327 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
328 let span = if tracing::Span::current().is_disabled() {
329 info_span!(
330 target: "rig::completions",
331 "chat",
332 gen_ai.operation.name = "chat",
333 gen_ai.provider.name = "mira",
334 gen_ai.request.model = self.model,
335 gen_ai.system_instructions = tracing::field::Empty,
336 gen_ai.response.id = tracing::field::Empty,
337 gen_ai.response.model = tracing::field::Empty,
338 gen_ai.usage.output_tokens = tracing::field::Empty,
339 gen_ai.usage.input_tokens = tracing::field::Empty,
340 )
341 } else {
342 tracing::Span::current()
343 };
344
345 span.record("gen_ai.system_instructions", &completion_request.preamble);
346
347 if !completion_request.tools.is_empty() {
348 tracing::warn!(target: "rig::completions",
349 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
350 len = completion_request.tools.len()
351 );
352 }
353
354 if completion_request.tool_choice.is_some() {
355 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
356 }
357
358 if completion_request.additional_params.is_some() {
359 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
360 }
361
362 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
363
364 if tracing::enabled!(tracing::Level::TRACE) {
365 tracing::trace!(target: "rig::completions",
366 "Mira completion request: {}",
367 serde_json::to_string_pretty(&request)?
368 );
369 }
370
371 let body = serde_json::to_vec(&request)?;
372
373 let req = self
374 .client
375 .post("/v1/chat/completions")?
376 .body(body)
377 .map_err(http_client::Error::from)?;
378
379 let async_block = async move {
380 let response = self
381 .client
382 .send::<_, bytes::Bytes>(req)
383 .await
384 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
385
386 let status = response.status();
387 let response_body = response.into_body().into_future().await?.to_vec();
388
389 if !status.is_success() {
390 let status = status.as_u16();
391 let error_text = String::from_utf8_lossy(&response_body).to_string();
392 return Err(CompletionError::ProviderError(format!(
393 "API error: {status} - {error_text}"
394 )));
395 }
396
397 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
398
399 if tracing::enabled!(tracing::Level::TRACE) {
400 tracing::trace!(target: "rig::completions",
401 "Mira completion response: {}",
402 serde_json::to_string_pretty(&response)?
403 );
404 }
405
406 if let CompletionResponse::Structured {
407 id, model, usage, ..
408 } = &response
409 {
410 let span = tracing::Span::current();
411 span.record("gen_ai.response.model_name", model);
412 span.record("gen_ai.response.id", id);
413 if let Some(usage) = usage {
414 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
415 span.record(
416 "gen_ai.usage.output_tokens",
417 usage.total_tokens - usage.prompt_tokens,
418 );
419 }
420 }
421
422 response.try_into()
423 };
424
425 async_block.instrument(span).await
426 }
427
428 #[cfg_attr(feature = "worker", worker::send)]
429 async fn stream(
430 &self,
431 completion_request: CompletionRequest,
432 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
433 let span = if tracing::Span::current().is_disabled() {
434 info_span!(
435 target: "rig::completions",
436 "chat_streaming",
437 gen_ai.operation.name = "chat_streaming",
438 gen_ai.provider.name = "mira",
439 gen_ai.request.model = self.model,
440 gen_ai.system_instructions = tracing::field::Empty,
441 gen_ai.response.id = tracing::field::Empty,
442 gen_ai.response.model = tracing::field::Empty,
443 gen_ai.usage.output_tokens = tracing::field::Empty,
444 gen_ai.usage.input_tokens = tracing::field::Empty,
445 )
446 } else {
447 tracing::Span::current()
448 };
449
450 span.record("gen_ai.system_instructions", &completion_request.preamble);
451
452 if !completion_request.tools.is_empty() {
453 tracing::warn!(target: "rig::completions",
454 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
455 len = completion_request.tools.len()
456 );
457 }
458
459 if completion_request.tool_choice.is_some() {
460 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
461 }
462
463 if completion_request.additional_params.is_some() {
464 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
465 }
466 let mut request =
467 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
468 request.stream = true;
469
470 if tracing::enabled!(tracing::Level::TRACE) {
471 tracing::trace!(target: "rig::completions",
472 "Mira completion request: {}",
473 serde_json::to_string_pretty(&request)?
474 );
475 }
476
477 let body = serde_json::to_vec(&request)?;
478
479 let req = self
480 .client
481 .post("/v1/chat/completions")?
482 .body(body)
483 .map_err(http_client::Error::from)?;
484
485 send_compatible_streaming_request(self.client.clone(), req)
486 .instrument(span)
487 .await
488 }
489}
490
491impl From<ApiErrorResponse> for CompletionError {
492 fn from(err: ApiErrorResponse) -> Self {
493 CompletionError::ProviderError(err.message)
494 }
495}
496
497impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
498 type Error = CompletionError;
499
500 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
501 let (content, usage) = match &response {
502 CompletionResponse::Structured { choices, usage, .. } => {
503 let choice = choices.first().ok_or_else(|| {
504 CompletionError::ResponseError("Response contained no choices".to_owned())
505 })?;
506
507 let usage = usage
508 .as_ref()
509 .map(|usage| completion::Usage {
510 input_tokens: usage.prompt_tokens as u64,
511 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
512 total_tokens: usage.total_tokens as u64,
513 })
514 .unwrap_or_default();
515
516 let message = message::Message::try_from(choice.message.clone())?;
518
519 let content = match message {
520 Message::Assistant { content, .. } => {
521 if content.is_empty() {
522 return Err(CompletionError::ResponseError(
523 "Response contained empty content".to_owned(),
524 ));
525 }
526
527 for c in content.iter() {
529 if !matches!(c, AssistantContent::Text(_)) {
530 tracing::warn!(target: "rig",
531 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
532 );
533 }
534 }
535
536 content.iter().map(|c| {
537 match c {
538 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
539 other => Err(CompletionError::ResponseError(
540 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
541 ))
542 }
543 }).collect::<Result<Vec<_>, _>>()?
544 }
545 Message::User { .. } => {
546 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
547 return Err(CompletionError::ResponseError(
548 "Received user message in response where assistant message was expected".to_owned()
549 ));
550 }
551 };
552
553 (content, usage)
554 }
555 CompletionResponse::Simple(text) => (
556 vec![completion::AssistantContent::text(text)],
557 completion::Usage::new(),
558 ),
559 };
560
561 let choice = OneOrMany::many(content).map_err(|_| {
562 CompletionError::ResponseError(
563 "Response contained no message or tool call (empty)".to_owned(),
564 )
565 })?;
566
567 Ok(completion::CompletionResponse {
568 choice,
569 usage,
570 raw_response: response,
571 })
572 }
573}
574
575#[derive(Clone, Debug, Deserialize, Serialize)]
576pub struct Usage {
577 pub prompt_tokens: usize,
578 pub total_tokens: usize,
579}
580
581impl std::fmt::Display for Usage {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 write!(
584 f,
585 "Prompt tokens: {} Total tokens: {}",
586 self.prompt_tokens, self.total_tokens
587 )
588 }
589}
590
591impl From<Message> for serde_json::Value {
592 fn from(msg: Message) -> Self {
593 match msg {
594 Message::User { content } => {
595 let text = content
596 .iter()
597 .map(|c| match c {
598 UserContent::Text(text) => &text.text,
599 _ => "",
600 })
601 .collect::<Vec<_>>()
602 .join("\n");
603 serde_json::json!({
604 "role": "user",
605 "content": text
606 })
607 }
608 Message::Assistant { content, .. } => {
609 let text = content
610 .iter()
611 .map(|c| match c {
612 AssistantContent::Text(text) => &text.text,
613 _ => "",
614 })
615 .collect::<Vec<_>>()
616 .join("\n");
617 serde_json::json!({
618 "role": "assistant",
619 "content": text
620 })
621 }
622 }
623 }
624}
625
626impl TryFrom<serde_json::Value> for Message {
627 type Error = CompletionError;
628
629 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
630 let role = value["role"].as_str().ok_or_else(|| {
631 CompletionError::ResponseError("Message missing role field".to_owned())
632 })?;
633
634 let content = match value.get("content") {
636 Some(content) => match content {
637 serde_json::Value::String(s) => s.clone(),
638 serde_json::Value::Array(arr) => arr
639 .iter()
640 .filter_map(|c| {
641 c.get("text")
642 .and_then(|t| t.as_str())
643 .map(|text| text.to_string())
644 })
645 .collect::<Vec<_>>()
646 .join("\n"),
647 _ => {
648 return Err(CompletionError::ResponseError(
649 "Message content must be string or array".to_owned(),
650 ));
651 }
652 },
653 None => {
654 return Err(CompletionError::ResponseError(
655 "Message missing content field".to_owned(),
656 ));
657 }
658 };
659
660 match role {
661 "user" => Ok(Message::User {
662 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
663 }),
664 "assistant" => Ok(Message::Assistant {
665 id: None,
666 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
667 }),
668 _ => Err(CompletionError::ResponseError(format!(
669 "Unsupported message role: {role}"
670 ))),
671 }
672 }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";
    const QUESTION: &str = "What can you help me with?";

    #[test]
    fn test_deserialize_message() {
        // The same assistant text in both accepted `content` shapes, plus a user turn.
        let assistant_message_json = json!({ "role": "assistant", "content": GREETING });
        let user_message_json = json!({ "role": "user", "content": QUESTION });
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{ "type": "text", "text": GREETING }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        let expected_assistant = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: QUESTION.to_string()
            })
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);
    }

    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Round-trip through the JSON wire shape and back.
        let converted_message: Message = serde_json::Value::from(original_message.clone())
            .try_into()
            .unwrap();

        assert_eq!(original_message, converted_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let usage = Usage {
            prompt_tokens: 10,
            total_tokens: 20,
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(usage),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}