muna 0.0.12

Run prediction functions in your Rust apps.
/*
*   Muna
*   Copyright © 2026 NatML Inc. All Rights Reserved.
*/

use futures_core::Stream;
use std::collections::HashMap;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;

use crate::c;
use crate::client::{MunaClient, MunaError, RequestInput, Result};
use crate::types::{Acceleration, Prediction, PredictionResource, Value};

/// Make local predictions.
///
/// Cloning the service is cheap: the API client and the in-memory predictor
/// cache are reference-counted and shared between clones.
#[derive(Clone)]
pub struct LocalPredictionService {
    /// API client used to request predictions and to download resources.
    client: Arc<MunaClient>,
    /// Predictors that are currently loaded in memory, keyed by predictor tag.
    cache: Arc<tokio::sync::RwLock<HashMap<String, Arc<c::Predictor>>>>,
    /// Directory under which downloaded prediction resources are stored.
    cache_dir: PathBuf,
}

impl LocalPredictionService {
    /// Create a local prediction service backed by `client`.
    ///
    /// The in-memory predictor cache starts empty, and the on-disk resource
    /// cache directory is resolved once, at construction time.
    pub fn new(client: Arc<MunaClient>) -> Self {
        let cache_dir = get_cache_dir();
        Self {
            client,
            cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            cache_dir,
        }
    }

    /// Create a prediction.
    ///
    /// The behavior depends on `inputs`:
    ///
    /// * `None`: only request the prediction for `tag` from the API and return
    ///   it as-is; nothing is downloaded or executed.
    /// * `Some` but empty: request the prediction, download its resources into
    ///   the cache directory, and return the prediction as received from the
    ///   API (its resource URLs are not rewritten).
    /// * `Some` and non-empty: load the predictor for `tag` (reusing the
    ///   in-memory one when present) and run the prediction locally.
    ///
    /// `acceleration` defaults to `Acceleration::LocalAuto` and is only
    /// consulted when the predictor has to be loaded. `client_id` and
    /// `configuration_id` override the values that would otherwise be read
    /// from the native configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the API request, a resource download, the input
    /// conversion, or the native prediction call fails.
    pub async fn create(
        &self,
        tag: &str,
        inputs: Option<HashMap<String, Value>>,
        acceleration: Option<Acceleration>,
        client_id: Option<String>,
        configuration_id: Option<String>,
    ) -> Result<Prediction> {
        let inputs = match inputs {
            Some(inputs) => inputs,
            None => {
                return self
                    .create_raw_prediction(tag, client_id, configuration_id)
                    .await
            }
        };
        if inputs.is_empty() {
            let prediction = self
                .create_raw_prediction(tag, client_id, configuration_id)
                .await?;
            // Only the download side effect is wanted here; the rewritten
            // prediction is intentionally discarded.
            self.create_cached_prediction(&prediction).await?;
            return Ok(prediction);
        }
        let predictor = self
            .load_predictor(tag, &acceleration, client_id, configuration_id)
            .await?;
        let input_map = c::ValueMap::from_dict(&inputs)?;
        let prediction = predictor.create_prediction(&input_map)?;
        Ok(to_prediction(tag, &prediction))
    }

    /// Stream a prediction.
    ///
    /// Loads the predictor for `tag` (reusing the in-memory one when present)
    /// and returns a stream that yields one `Prediction` per item produced by
    /// the native prediction stream. The stream ends at the first error.
    ///
    /// # Errors
    ///
    /// Returns an error if the predictor cannot be loaded, the inputs cannot
    /// be converted, or the native stream cannot be created. Errors raised
    /// while streaming are delivered as stream items.
    pub async fn stream(
        &self,
        tag: &str,
        inputs: HashMap<String, Value>,
        acceleration: Option<Acceleration>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Prediction>> + Send>>> {
        let predictor = self.load_predictor(tag, &acceleration, None, None).await?;
        let tag = tag.to_string();
        let input_map = c::ValueMap::from_dict(&inputs)?;
        // NOTE(review): the stream is created from the predictor's raw pointer
        // and the `Arc` is released when this function returns. Whether the
        // native predictor must outlive the stream is not visible from here;
        // confirm against `c::PredictionStream`.
        let stream_handle = c::PredictionStream::create(predictor.raw_ptr(), &input_map)?;
        let stream = async_stream::try_stream! {
            for prediction in stream_handle {
                let prediction = prediction?;
                yield to_prediction(&tag, &prediction);
            }
        };
        Ok(Box::pin(stream))
    }

    /// Delete a predictor that is loaded in memory.
    ///
    /// Returns `true` if a predictor for `tag` was in the in-memory cache and
    /// has been removed, `false` otherwise. Downloaded resources on disk are
    /// not touched.
    pub async fn delete(&self, tag: &str) -> Result<bool> {
        let mut cache = self.cache.write().await;
        Ok(cache.remove(tag).is_some())
    }

    /// Request a prediction for `tag` from the API without running it.
    ///
    /// `client_id` falls back to the native configuration's client identifier
    /// and then to `"rust"`. `configuration_id` falls back to the native
    /// configuration's unique identifier and is omitted from the request when
    /// neither is available.
    async fn create_raw_prediction(
        &self,
        tag: &str,
        client_id: Option<String>,
        configuration_id: Option<String>,
    ) -> Result<Prediction> {
        let client_id = client_id
            .or_else(|| c::Configuration::get_client_id().ok())
            .unwrap_or_else(|| "rust".to_string());
        let configuration_id = configuration_id.or_else(|| c::Configuration::get_unique_id().ok());
        let mut body = serde_json::json!({
            "tag": tag,
            "clientId": client_id,
        });
        if let Some(config_id) = configuration_id {
            body["configurationId"] = serde_json::Value::String(config_id);
        }
        self.client
            .request(RequestInput::post("/predictions").body(body))
            .await
    }

    /// Return the in-memory predictor for `tag`, creating it on first use.
    ///
    /// On a cache miss this requests the prediction from the API, downloads
    /// its resources, builds a native configuration from them and creates the
    /// predictor, which is then stored in the cache.
    ///
    /// The predictor handle itself is returned so that callers never have to
    /// look it up again: a concurrent `delete` between loading and a second
    /// lookup would otherwise leave them without an entry.
    ///
    /// # Errors
    ///
    /// Returns `MunaError::Prediction` when the API response carries no
    /// configuration token, and propagates any request, download or native
    /// error.
    async fn load_predictor(
        &self,
        tag: &str,
        acceleration: &Option<Acceleration>,
        client_id: Option<String>,
        configuration_id: Option<String>,
    ) -> Result<Arc<c::Predictor>> {
        // Fast path: reuse an already loaded predictor. The read guard is a
        // temporary of this statement and is released before any other await.
        if let Some(predictor) = self.cache.read().await.get(tag).cloned() {
            return Ok(predictor);
        }
        let acceleration = acceleration.clone().unwrap_or(Acceleration::LocalAuto);
        let prediction = self
            .create_raw_prediction(tag, client_id, configuration_id)
            .await?;
        let prediction = self.create_cached_prediction(&prediction).await?;
        let config_token = prediction.configuration.as_deref().ok_or_else(|| {
            MunaError::Prediction(format!(
                "Failed to create {tag} prediction because configuration token is missing"
            ))
        })?;
        let mut configuration = c::Configuration::new()?;
        configuration.set_tag(tag)?;
        configuration.set_token(&config_token)?;
        configuration.set_acceleration(c::acceleration_to_c(&acceleration))?;
        if let Some(resources) = &prediction.resources {
            for resource in resources {
                configuration.add_resource(&resource.kind, &resource.url)?;
            }
        }
        let predictor = c::Predictor::new(&configuration)?;
        let mut cache = self.cache.write().await;
        // If another task loaded the same tag in the meantime, keep and return
        // its predictor; the one created here is dropped.
        let cached = cache
            .entry(tag.to_string())
            .or_insert_with(|| Arc::new(predictor));
        Ok(Arc::clone(cached))
    }

    /// Compute the local cache path for a resource.
    ///
    /// The first path component below the cache directory is the last segment
    /// of the resource URL's path. It falls back to `"resource"` when the URL
    /// cannot be parsed, has no path segments, or ends in an empty segment
    /// (for example a trailing `/`). When the resource has a name, the name is
    /// appended as a child of that component.
    fn get_resource_path(&self, resource: &PredictionResource) -> PathBuf {
        let url = url::Url::parse(&resource.url).ok();
        let stem = url
            .as_ref()
            .and_then(|u| u.path_segments())
            .and_then(|mut segments| segments.next_back())
            // An empty segment would make the path collapse to the cache
            // directory itself, which always exists, so the download would be
            // skipped.
            .filter(|segment| !segment.is_empty())
            .unwrap_or("resource");
        let mut path = self.cache_dir.join(stem);
        if let Some(name) = &resource.name {
            path = path.join(name);
        }
        path
    }

    /// Download a prediction's resources and return a new prediction whose
    /// resource URLs point to the downloaded local paths.
    ///
    /// Resources are downloaded one after another, in order; the first failure
    /// aborts the whole operation. A prediction without resources is returned
    /// as an unchanged copy.
    async fn create_cached_prediction(&self, prediction: &Prediction) -> Result<Prediction> {
        let resources = match &prediction.resources {
            Some(resources) => {
                let mut materialized = Vec::with_capacity(resources.len());
                for resource in resources {
                    materialized.push(self.download_resource(resource).await?);
                }
                Some(materialized)
            }
            None => None,
        };
        Ok(Prediction {
            resources,
            ..prediction.clone()
        })
    }

    /// Download a single resource and return it with its URL set to the local
    /// downloaded path.
    ///
    /// The download is skipped when something already exists at the target
    /// path.
    async fn download_resource(&self, resource: &PredictionResource) -> Result<PredictionResource> {
        let path = self.get_resource_path(resource);
        if !path.exists() {
            self.client.download(&resource.url, &path).await?;
        }
        Ok(PredictionResource {
            url: path.to_string_lossy().into_owned(),
            ..resource.clone()
        })
    }
}

/// Convert a native prediction into the public `Prediction` type.
///
/// Results are read by walking the native result map by index: each key is
/// looked up and its value converted to an object. Entries whose key, value or
/// conversion fails are skipped, and `results` is `None` when the result map
/// itself cannot be read. `configuration` and `resources` are always `None`,
/// and `created` is filled from `chrono_now`.
fn to_prediction(tag: &str, prediction: &c::Prediction) -> Prediction {
    let results = match prediction.results() {
        Ok(outputs) => Some(
            (0..outputs.len())
                .filter_map(|index| {
                    let name = outputs.key(index).ok()?;
                    let output = outputs.get(&name).ok()?;
                    output.to_object().ok()
                })
                .collect(),
        ),
        Err(_) => None,
    };
    // Read the remaining fields in the same order the struct lists them.
    let id = prediction.id().unwrap_or_default();
    let tag = tag.to_string();
    let created = chrono_now();
    let latency = prediction.latency().ok();
    let error = prediction.error().ok().flatten();
    let logs = prediction.logs().ok().flatten();
    Prediction {
        id,
        tag,
        created,
        configuration: None,
        resources: None,
        results,
        latency,
        error,
        logs,
    }
}

/// Path of the `cache` directory inside the Muna home directory.
///
/// Creating the directory is best effort: a failure is ignored and the path
/// is returned regardless.
fn get_cache_dir() -> PathBuf {
    let mut cache_dir = get_muna_home();
    cache_dir.push("cache");
    // Best effort only; callers get the path even if it could not be created.
    let _ = std::fs::create_dir_all(&cache_dir);
    cache_dir
}

/// Resolve the Muna home directory.
///
/// Candidates are tried in this order, and the first one that can be created
/// and written to is returned:
///
/// 1. the `MUNA_HOME` environment variable, when set to a non-empty value;
/// 2. `.fxn` inside the user's home directory, when one is known;
/// 3. `.fxn` inside the system temporary directory.
///
/// Writability is probed by writing and then removing a small marker file.
/// When no candidate is usable, the temporary-directory path is returned
/// without having been verified.
fn get_muna_home() -> PathBuf {
    let candidates = std::env::var("MUNA_HOME")
        .ok()
        // An empty value would turn into the empty path, which
        // `create_dir_all` accepts as a no-op; the home would then silently
        // resolve against the current working directory. Treat it as unset.
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
        .into_iter()
        .chain(home::home_dir().map(|h| h.join(".fxn")))
        .chain(std::iter::once(std::env::temp_dir().join(".fxn")));
    for dir in candidates {
        if std::fs::create_dir_all(&dir).is_ok() {
            let test = dir.join(".muna_write_test");
            if std::fs::write(&test, "muna").is_ok() {
                // Leaving the marker behind is harmless, so a failed removal
                // is ignored.
                let _ = std::fs::remove_file(&test);
                return dir;
            }
        }
    }
    std::env::temp_dir().join(".fxn")
}

/// Current time as whole seconds since the Unix epoch, rendered in decimal.
///
/// If the system clock reads earlier than the epoch, the result is `"0"`.
fn chrono_now() -> String {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::default(),
    };
    since_epoch.as_secs().to_string()
}