1use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16 OneOrMany,
17 completion::{self, CompletionError, CompletionRequest},
18 impl_conversion_traits,
19 message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::Deserialize;
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
/// Errors returned by the Mira [`Client`] helpers (`Client::new`, `Client::list_models`).
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key was rejected.
    /// NOTE(review): not constructed anywhere in the code visible here.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; the payload is the status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// An error reported by `reqwest` (converted automatically via `#[from]`).
    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// A byte buffer could not be decoded as UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// A response body could not be parsed as the expected JSON.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
41
/// Error payload of the shape `{ "message": "..." }`.
///
/// Converted into `CompletionError::ProviderError` by its `From` impl in this module.
/// NOTE(review): nothing in the visible code deserializes into this type.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description from the API.
    message: String,
}
46
/// A chat message as it appears on the wire: a role string plus plain-text content.
#[derive(Debug, Deserialize, Clone)]
pub struct RawMessage {
    /// Message role; only `"user"` and `"assistant"` convert into a rig `Message`.
    pub role: String,
    /// Plain-text message body.
    pub content: String,
}
52
/// Default base URL of the Mira API; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 id: None,
65 content: OneOrMany::one(AssistantContent::Text(message::Text {
66 text: raw.content,
67 })),
68 }),
69 _ => Err(CompletionError::ResponseError(format!(
70 "Unsupported message role: {}",
71 raw.role
72 ))),
73 }
74 }
75}
76
77#[derive(Debug, Deserialize)]
78#[serde(untagged)]
79pub enum CompletionResponse {
80 Structured {
81 id: String,
82 object: String,
83 created: u64,
84 model: String,
85 choices: Vec<ChatChoice>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 usage: Option<Usage>,
88 },
89 Simple(String),
90}
91
/// One entry of a structured response's `choices` array.
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
    /// The generated message.
    pub message: RawMessage,
    /// Why generation stopped; `None` when the field is absent.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Position of this choice in the array; `None` when the field is absent.
    #[serde(default)]
    pub index: Option<usize>,
}
100
/// Body of `GET /v1/models`; only the `data` array is read.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}
105
/// A single model entry from `GET /v1/models`; only its identifier is kept.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    id: String,
}
110
/// HTTP client for the Mira API.
///
/// `Debug` is implemented by hand so that the API key is redacted.
#[derive(Clone)]
pub struct Client {
    /// Base URL that endpoint paths (`/v1/...`) are appended to.
    base_url: String,
    /// Underlying `reqwest` client used for all requests.
    http_client: reqwest::Client,
    /// Secret API key, passed to `bearer_auth` on authenticated requests.
    api_key: String,
    /// Default headers attached to each request.
    headers: HeaderMap,
}
119
120impl std::fmt::Debug for Client {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("Client")
123 .field("base_url", &self.base_url)
124 .field("http_client", &self.http_client)
125 .field("api_key", &"<REDACTED>")
126 .field("headers", &self.headers)
127 .finish()
128 }
129}
130
131impl Client {
132 pub fn new(api_key: &str) -> Result<Self, MiraError> {
134 let mut headers = HeaderMap::new();
135 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136 headers.insert(
137 reqwest::header::ACCEPT,
138 HeaderValue::from_static("application/json"),
139 );
140 headers.insert(
141 reqwest::header::USER_AGENT,
142 HeaderValue::from_static("rig-client/1.0"),
143 );
144
145 Ok(Self {
146 base_url: MIRA_API_BASE_URL.to_string(),
147 api_key: api_key.to_string(),
148 http_client: reqwest::Client::builder()
149 .build()
150 .expect("Failed to build HTTP client"),
151 headers,
152 })
153 }
154
155 pub fn new_with_base_url(
157 api_key: &str,
158 base_url: impl Into<String>,
159 ) -> Result<Self, MiraError> {
160 let mut client = Self::new(api_key)?;
161 client.base_url = base_url.into();
162 Ok(client)
163 }
164
165 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
168 self.http_client = client;
169
170 self
171 }
172
173 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
175 let url = format!("{}/v1/models", self.base_url);
176
177 let response = self
178 .http_client
179 .get(&url)
180 .bearer_auth(&self.api_key)
181 .headers(self.headers.clone())
182 .send()
183 .await?;
184
185 let status = response.status();
186
187 if !status.is_success() {
188 let _error_text = response.text().await.unwrap_or_default();
190 tracing::error!("Error response: {}", _error_text);
191 return Err(MiraError::ApiError(status.as_u16()));
192 }
193
194 let response_text = response.text().await?;
195
196 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
197 tracing::error!("Failed to parse response: {}", e);
198 MiraError::JsonError(e)
199 })?;
200
201 Ok(models.data.into_iter().map(|model| model.id).collect())
202 }
203}
204
205impl ProviderClient for Client {
206 fn from_env() -> Self {
209 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
210 Self::new(&api_key).expect("Could not create Mira Client")
211 }
212}
213
214impl CompletionClient for Client {
215 type CompletionModel = CompletionModel;
216 fn completion_model(&self, model: &str) -> CompletionModel {
218 CompletionModel::new(self.to_owned(), model)
219 }
220}
221
// Presumably generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration`
// and `AsAudioGeneration` conversion impls for `Client`. Only completions are
// implemented in this module. NOTE(review): confirm the exact semantics against
// the `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
228
/// A Mira chat-completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
}
235
236impl CompletionModel {
237 pub fn new(client: Client, model: &str) -> Self {
238 Self {
239 client,
240 model: model.to_string(),
241 }
242 }
243
244 fn create_completion_request(
245 &self,
246 completion_request: CompletionRequest,
247 ) -> Result<Value, CompletionError> {
248 let mut messages = Vec::new();
249
250 if let Some(preamble) = &completion_request.preamble {
252 messages.push(serde_json::json!({
253 "role": "user",
254 "content": preamble.to_string()
255 }));
256 }
257
258 if let Some(Message::User { content }) = completion_request.normalized_documents() {
260 let text = content
261 .into_iter()
262 .filter_map(|doc| match doc {
263 UserContent::Document(doc) => Some(doc.data),
264 UserContent::Text(text) => Some(text.text),
265
266 _ => None,
268 })
269 .collect::<Vec<_>>()
270 .join("\n");
271
272 messages.push(serde_json::json!({
273 "role": "user",
274 "content": text
275 }));
276 }
277
278 for msg in completion_request.chat_history {
280 let (role, content) = match msg {
281 Message::User { content } => {
282 let text = content
283 .iter()
284 .map(|c| match c {
285 UserContent::Text(text) => &text.text,
286 _ => "",
287 })
288 .collect::<Vec<_>>()
289 .join("\n");
290 ("user", text)
291 }
292 Message::Assistant { content, .. } => {
293 let text = content
294 .iter()
295 .map(|c| match c {
296 AssistantContent::Text(text) => &text.text,
297 _ => "",
298 })
299 .collect::<Vec<_>>()
300 .join("\n");
301 ("assistant", text)
302 }
303 };
304 messages.push(serde_json::json!({
305 "role": role,
306 "content": content
307 }));
308 }
309
310 let request = serde_json::json!({
311 "model": self.model,
312 "messages": messages,
313 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
314 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
315 "stream": false
316 });
317
318 Ok(request)
319 }
320}
321
322impl completion::CompletionModel for CompletionModel {
323 type Response = CompletionResponse;
324 type StreamingResponse = openai::StreamingCompletionResponse;
325
326 #[cfg_attr(feature = "worker", worker::send)]
327 async fn completion(
328 &self,
329 completion_request: CompletionRequest,
330 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
331 if !completion_request.tools.is_empty() {
332 tracing::warn!(target: "rig",
333 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
334 completion_request.tools.len()
335 );
336 }
337
338 let mira_request = self.create_completion_request(completion_request)?;
339
340 let response = self
341 .client
342 .http_client
343 .post(format!("{}/v1/chat/completions", self.client.base_url))
344 .bearer_auth(&self.client.api_key)
345 .headers(self.client.headers.clone())
346 .json(&mira_request)
347 .send()
348 .await
349 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
350
351 if !response.status().is_success() {
352 let status = response.status().as_u16();
353 let error_text = response.text().await.unwrap_or_default();
354 return Err(CompletionError::ProviderError(format!(
355 "API error: {status} - {error_text}"
356 )));
357 }
358
359 let response: CompletionResponse = response
360 .json()
361 .await
362 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
363
364 response.try_into()
365 }
366
367 #[cfg_attr(feature = "worker", worker::send)]
368 async fn stream(
369 &self,
370 completion_request: CompletionRequest,
371 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
372 let mut request = self.create_completion_request(completion_request)?;
373
374 request = merge(request, json!({"stream": true}));
375
376 let builder = self
377 .client
378 .http_client
379 .post(format!("{}/v1/chat/completions", self.client.base_url))
380 .headers(self.client.headers.clone())
381 .json(&request);
382
383 send_compatible_streaming_request(builder).await
384 }
385}
386
387impl From<ApiErrorResponse> for CompletionError {
388 fn from(err: ApiErrorResponse) -> Self {
389 CompletionError::ProviderError(err.message)
390 }
391}
392
393impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
394 type Error = CompletionError;
395
396 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
397 let content = match &response {
398 CompletionResponse::Structured { choices, .. } => {
399 let choice = choices.first().ok_or_else(|| {
400 CompletionError::ResponseError("Response contained no choices".to_owned())
401 })?;
402
403 let message = message::Message::try_from(choice.message.clone())?;
405
406 match message {
407 Message::Assistant { content, .. } => {
408 if content.is_empty() {
409 return Err(CompletionError::ResponseError(
410 "Response contained empty content".to_owned(),
411 ));
412 }
413
414 for c in content.iter() {
416 if !matches!(c, AssistantContent::Text(_)) {
417 tracing::warn!(target: "rig",
418 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
419 );
420 }
421 }
422
423 content.iter().map(|c| {
424 match c {
425 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
426 other => Err(CompletionError::ResponseError(
427 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
428 ))
429 }
430 }).collect::<Result<Vec<_>, _>>()?
431 }
432 Message::User { .. } => {
433 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
434 return Err(CompletionError::ResponseError(
435 "Received user message in response where assistant message was expected".to_owned()
436 ));
437 }
438 }
439 }
440 CompletionResponse::Simple(text) => {
441 vec![completion::AssistantContent::text(text)]
442 }
443 };
444
445 let choice = OneOrMany::many(content).map_err(|_| {
446 CompletionError::ResponseError(
447 "Response contained no message or tool call (empty)".to_owned(),
448 )
449 })?;
450
451 Ok(completion::CompletionResponse {
452 choice,
453 raw_response: response,
454 })
455 }
456}
457
/// Token accounting reported with a structured response.
///
/// Only the prompt and total counts are captured.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Total tokens reported for the request.
    pub total_tokens: usize,
}
463
464impl std::fmt::Display for Usage {
465 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466 write!(
467 f,
468 "Prompt tokens: {} Total tokens: {}",
469 self.prompt_tokens, self.total_tokens
470 )
471 }
472}
473
474impl From<Message> for serde_json::Value {
475 fn from(msg: Message) -> Self {
476 match msg {
477 Message::User { content } => {
478 let text = content
479 .iter()
480 .map(|c| match c {
481 UserContent::Text(text) => &text.text,
482 _ => "",
483 })
484 .collect::<Vec<_>>()
485 .join("\n");
486 serde_json::json!({
487 "role": "user",
488 "content": text
489 })
490 }
491 Message::Assistant { content, .. } => {
492 let text = content
493 .iter()
494 .map(|c| match c {
495 AssistantContent::Text(text) => &text.text,
496 _ => "",
497 })
498 .collect::<Vec<_>>()
499 .join("\n");
500 serde_json::json!({
501 "role": "assistant",
502 "content": text
503 })
504 }
505 }
506 }
507}
508
509impl TryFrom<serde_json::Value> for Message {
510 type Error = CompletionError;
511
512 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
513 let role = value["role"].as_str().ok_or_else(|| {
514 CompletionError::ResponseError("Message missing role field".to_owned())
515 })?;
516
517 let content = match value.get("content") {
519 Some(content) => match content {
520 serde_json::Value::String(s) => s.clone(),
521 serde_json::Value::Array(arr) => arr
522 .iter()
523 .filter_map(|c| {
524 c.get("text")
525 .and_then(|t| t.as_str())
526 .map(|text| text.to_string())
527 })
528 .collect::<Vec<_>>()
529 .join("\n"),
530 _ => {
531 return Err(CompletionError::ResponseError(
532 "Message content must be string or array".to_owned(),
533 ));
534 }
535 },
536 None => {
537 return Err(CompletionError::ResponseError(
538 "Message missing content field".to_owned(),
539 ));
540 }
541 };
542
543 match role {
544 "user" => Ok(Message::User {
545 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
546 }),
547 "assistant" => Ok(Message::Assistant {
548 id: None,
549 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
550 }),
551 _ => Err(CompletionError::ResponseError(format!(
552 "Unsupported message role: {role}"
553 ))),
554 }
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// String and array `content` forms both decode into single-text messages.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A text message survives a round trip through its JSON encoding.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().into();

        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response's first choice becomes the completion content.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// A plain-string response becomes a single text item.
    #[test]
    fn test_simple_completion_response_conversion() {
        let completion_response: completion::CompletionResponse<CompletionResponse> =
            CompletionResponse::Simple("Plain text".to_string())
                .try_into()
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Plain text")
        );
    }

    /// A structured response without choices is rejected rather than producing empty output.
    #[test]
    fn test_completion_response_without_choices_is_error() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_empty".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "deepseek-r1".to_string(),
            choices: vec![],
            usage: None,
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
            mira_response.try_into();

        assert!(matches!(result, Err(CompletionError::ResponseError(_))));
    }

    /// Roles other than `user`/`assistant` are rejected when converting raw messages.
    #[test]
    fn test_raw_message_with_unsupported_role_is_error() {
        let raw = RawMessage {
            role: "system".to_string(),
            content: "You are helpful".to_string(),
        };

        assert!(matches!(
            message::Message::try_from(raw),
            Err(CompletionError::ResponseError(_))
        ));
    }

    /// Untagged deserialization picks `Structured` for objects (optional fields may be
    /// absent) and `Simple` for bare strings.
    #[test]
    fn test_deserialize_completion_response_variants() {
        let structured: CompletionResponse = serde_json::from_value(json!({
            "id": "resp_1",
            "object": "chat.completion",
            "created": 1,
            "model": "deepseek-r1",
            "choices": [{
                "message": { "role": "assistant", "content": "Hi" }
            }]
        }))
        .unwrap();

        match structured {
            CompletionResponse::Structured { choices, usage, .. } => {
                assert_eq!(choices.len(), 1);
                assert_eq!(choices[0].message.content, "Hi");
                assert!(choices[0].finish_reason.is_none());
                assert!(choices[0].index.is_none());
                assert!(usage.is_none());
            }
            CompletionResponse::Simple(_) => panic!("Expected structured response"),
        }

        let simple: CompletionResponse = serde_json::from_value(json!("Just text")).unwrap();
        assert!(matches!(simple, CompletionResponse::Simple(ref text) if text == "Just text"));
    }
}