muna 0.0.12

Run prediction functions in your Rust apps.
/*
*   Muna
*   Copyright © 2026 NatML Inc. All Rights Reserved.
*/

use crate::types::Acceleration;
use serde::{Deserialize, Serialize};

/// A single message in a chat conversation.
///
/// Each message records the `role` of its author and, optionally, its text
/// `content`. When `content` is absent from the input it deserializes to `None`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionMessage {
    /// Role of the message author.
    pub role: String,
    /// Text content of the message, if any.
    #[serde(default)]
    pub content: Option<String>,
}

/// Token accounting for a chat completion request.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u64,
    /// Tokens produced in the generated completion.
    pub completion_tokens: u64,
    /// Total tokens consumed by the request.
    pub total_tokens: u64,
}

/// One candidate completion within a [`ChatCompletion`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionChoice {
    /// Position of this choice within the `choices` list.
    pub index: usize,
    /// Message the model produced for this choice.
    pub message: ChatCompletionMessage,
    /// Why the model stopped generating tokens, when reported.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Log-probability data for this choice, kept as raw JSON.
    #[serde(default)]
    pub logprobs: Option<serde_json::Value>,
}

/// A full chat completion response.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletion {
    /// Object type; always `chat.completion`.
    pub object: String,
    /// Unique identifier of this chat completion.
    pub id: String,
    /// Model that produced the completion.
    pub model: String,
    /// Candidate completions generated by the model.
    pub choices: Vec<ChatCompletionChoice>,
    /// Creation time, as a Unix timestamp in seconds.
    pub created: u64,
    /// Token usage for the request, if reported.
    #[serde(default)]
    pub usage: Option<ChatCompletionUsage>,
}

/// Incremental message fragment carried by a chat completion chunk choice.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionDelta {
    /// Role of the message author, if present in this delta.
    #[serde(default)]
    pub role: Option<String>,
    /// Fragment of message content, if present in this delta.
    #[serde(default)]
    pub content: Option<String>,
}

/// One candidate completion within a [`ChatCompletionChunk`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionChunkChoice {
    /// Position of this choice within the `choices` list.
    pub index: usize,
    /// Message fragment the model produced for this choice, if any.
    #[serde(default)]
    pub delta: Option<ChatCompletionDelta>,
    /// Why the model stopped generating tokens, when reported.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Log-probability data for this choice, kept as raw JSON.
    #[serde(default)]
    pub logprobs: Option<serde_json::Value>,
}

/// One chunk of a chat completion.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    /// Object type; always `chat.completion.chunk`.
    pub object: String,
    /// Identifier of the chat completion, shared by every chunk of that completion.
    pub id: String,
    /// Model that produced the completion.
    pub model: String,
    /// Candidate completion fragments in this chunk.
    pub choices: Vec<ChatCompletionChunkChoice>,
    /// Creation time of this chunk, as a Unix timestamp in seconds.
    pub created: u64,
    /// Token usage for the request, if reported.
    #[serde(default)]
    pub usage: Option<ChatCompletionUsage>,
}

/// Reasoning effort for reasoning models.
///
/// Each variant serializes to, and deserializes from, its lowercase name:
/// `"minimal"`, `"low"`, `"medium"`, `"high"`, and `"xhigh"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
// `rename_all = "lowercase"` lowercases each variant name (`XHigh` -> `xhigh`),
// producing exactly the same wire names as per-variant renames would.
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionReasoningEffort {
    /// Reasoning effort, serialized as `"minimal"`.
    Minimal,
    /// Reasoning effort, serialized as `"low"`.
    Low,
    /// Reasoning effort, serialized as `"medium"`.
    Medium,
    /// Reasoning effort, serialized as `"high"`.
    High,
    /// Reasoning effort, serialized as `"xhigh"`.
    XHigh,
}

impl ChatCompletionReasoningEffort {
    /// Returns the wire name of this reasoning effort.
    ///
    /// The returned string is the same value the variant serializes to,
    /// e.g. `ChatCompletionReasoningEffort::XHigh.as_str() == "xhigh"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Minimal => "minimal",
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
            Self::XHigh => "xhigh",
        }
    }
}

/// Parameters for creating a chat completion.
///
/// `Default` yields an empty model tag, no messages, and every optional
/// setting unset (`None`). Combine it with struct-update syntax to fill in
/// only the fields you need.
#[derive(Clone, Debug, Default)]
pub struct ChatCompletionCreateParams {
    /// Tag of the chat predictor to run.
    pub model: String,
    /// Conversation history so far.
    pub messages: Vec<ChatCompletionMessage>,
    /// Desired response format, as a JSON object.
    pub response_format: Option<serde_json::Map<String, serde_json::Value>>,
    /// Reasoning effort to request from reasoning models.
    pub reasoning_effort: Option<ChatCompletionReasoningEffort>,
    /// Upper bound on the number of completion tokens.
    pub max_completion_tokens: Option<i32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling coefficient.
    pub top_p: Option<f32>,
    /// Penalty applied based on token frequency.
    pub frequency_penalty: Option<f32>,
    /// Penalty applied based on token presence.
    pub presence_penalty: Option<f32>,
    /// Acceleration to use for the prediction.
    pub acceleration: Option<Acceleration>,
}

/// Embedding data, either as float values or base64-encoded bytes.
///
/// The enum is `#[serde(untagged)]`. It serializes as the bare inner value:
/// an array of numbers for `Float`, or a string for `Base64`. Deserialization
/// tries the variants in declaration order and keeps the first match, so the
/// variant order below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingData {
    /// Embedding as a vector of float values.
    Float(Vec<f32>),
    /// Embedding as a base64-encoded string.
    Base64(String),
}

/// A single embedding vector in an embeddings response.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct Embedding {
    /// Object type; always `embedding`.
    pub object: String,
    /// Embedding values, either as floats or as a base64-encoded string.
    pub embedding: EmbeddingData,
    /// Position of this embedding within the response list.
    pub index: usize,
}

/// Usage information for an embedding request.
///
/// Derives `Default` (all counts zero), matching `ChatCompletionUsage`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Number of tokens in the input prompt.
    pub prompt_tokens: u64,
    /// Total number of tokens used in the request.
    pub total_tokens: u64,
}

/// Response from creating embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingCreateResponse {
    /// Object type reported for the response.
    pub object: String,
    /// Model used to create the embeddings.
    pub model: String,
    /// Generated embeddings; each entry records its position in `Embedding::index`.
    pub data: Vec<Embedding>,
    /// Usage statistics for the embedding request.
    pub usage: EmbeddingUsage,
}