// mubit-sdk 0.7.0
//
// Umbrella Rust SDK for Mubit core/control planes
// Documentation
//! MuBit Learn Run Manager.
//!
//! Manages session lifecycle: call counting, background reflection.

use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

use crate::learn::LearnConfig;

/// State backing a single learn session: its identifier, configuration,
/// call counter, "ended" flag, and the HTTP client used for control-plane
/// requests.
///
/// The counter and the flag are atomics, so every method takes `&self`.
pub(crate) struct RunManagerInner {
    /// Session identifier; sent as `run_id` in the reflect and context requests.
    pub session_id: String,
    /// Configuration supplied to [`RunManagerInner::new`].
    pub config: LearnConfig,
    /// Number of calls recorded through `increment`.
    call_count: AtomicU64,
    /// Set by the first call to `end`; later calls see it and return early.
    ended: AtomicBool,
    /// HTTP client built in `new` (5-second timeout when the builder succeeds).
    http_client: reqwest::Client,
}

impl RunManagerInner {
    /// Creates a run manager for `config`.
    ///
    /// The session ID is `config.session_id` when present; otherwise it is the
    /// first 16 non-hyphen characters of a freshly generated v4 UUID string.
    /// The HTTP client is built with a 5-second timeout and falls back to
    /// `reqwest::Client::default()` if the builder fails.
    pub fn new(config: LearnConfig) -> Self {
        let session_id = config.session_id.clone().unwrap_or_else(|| {
            // Skip the hyphens while iterating instead of allocating an
            // intermediate string with `replace`.
            uuid::Uuid::new_v4()
                .to_string()
                .chars()
                .filter(|c| *c != '-')
                .take(16)
                .collect()
        });

        Self {
            session_id,
            config,
            call_count: AtomicU64::new(0),
            ended: AtomicBool::new(false),
            http_client: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(5))
                .build()
                .unwrap_or_default(),
        }
    }

    /// Records one call and, when `config.reflect_after_n_calls` is
    /// `Some(n)` with `n > 0`, starts a background reflection on every
    /// `n`-th call.
    pub fn increment(&self) {
        // `fetch_add` returns the previous value; `+ 1` yields the new count.
        let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1;
        if let Some(n) = self.config.reflect_after_n_calls {
            // `n > 0` guards the modulo against division by zero.
            if n > 0 && count % n == 0 {
                self.background_reflect();
            }
        }
    }

    /// Returns the number of calls recorded so far.
    pub fn call_count(&self) -> u64 {
        self.call_count.load(Ordering::Relaxed)
    }

    /// Marks the session as ended and, if `config.auto_reflect` is set,
    /// performs one final reflection.
    ///
    /// Only the first call does any work; subsequent calls return
    /// immediately. A failed reflection is ignored (best effort).
    pub async fn end(&self) {
        // `swap` returns the previous flag value: `true` means already ended.
        if self.ended.swap(true, Ordering::SeqCst) {
            return;
        }
        if self.config.auto_reflect {
            // Best effort: ending the session must not fail because the
            // reflect request did.
            let _ = self.do_reflect().await;
        }
    }

    /// Fires a reflection request on a detached background task.
    ///
    /// The request result is discarded. Nothing is spawned when the API key
    /// is empty (the request would be skipped anyway) or when the caller is
    /// not running inside a Tokio runtime.
    pub fn background_reflect(&self) {
        if self.config.api_key.is_empty() {
            return;
        }

        // `tokio::spawn` panics outside a runtime, and this method is
        // reachable from the synchronous `increment`. Look the runtime up
        // explicitly and skip the reflection when there is none.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            // The task must be `'static`, so it owns copies of what it needs.
            let endpoint = self.config.endpoint.clone();
            let api_key = self.config.api_key.clone();
            let session_id = self.session_id.clone();
            let client = self.http_client.clone();

            handle.spawn(async move {
                let _ = Self::reflect_http(&client, &endpoint, &api_key, &session_id).await;
            });
        }
    }

    /// Sends a reflection request for this session and waits for it.
    async fn do_reflect(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Self::reflect_http(
            &self.http_client,
            &self.config.endpoint,
            &self.config.api_key,
            &self.session_id,
        )
        .await
    }

    /// POSTs `{"run_id": session_id}` to `{endpoint}/v2/control/reflect`
    /// with bearer authentication.
    ///
    /// Returns `Ok(())` without sending anything when `api_key` is empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the server answers
    /// with a 4xx/5xx status.
    async fn reflect_http(
        client: &reqwest::Client,
        endpoint: &str,
        api_key: &str,
        session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        if api_key.is_empty() {
            return Ok(());
        }
        // Trim trailing slashes so the joined URL never contains `//`.
        let url = format!("{}/v2/control/reflect", endpoint.trim_end_matches('/'));
        client
            .post(&url)
            .bearer_auth(api_key)
            .json(&serde_json::json!({"run_id": session_id}))
            .send()
            .await?
            // `send` only fails on transport errors; surface HTTP error
            // statuses as well instead of reporting them as success.
            .error_for_status()?;
        Ok(())
    }

    /// Fetches the context block for `query` from
    /// `{endpoint}/v2/control/context`.
    ///
    /// The request carries the session ID as `run_id`, the query, the
    /// `"structured"` format and `config.max_token_budget`;
    /// `config.entry_types` and `config.context_sections` are included only
    /// when non-empty.
    ///
    /// Returns an empty string when the API key is empty, when the server
    /// answers with a non-success status, or when the response has no
    /// string-valued `context_block` field.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the response body
    /// is not valid JSON.
    pub async fn get_context_http(
        &self,
        query: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        if self.config.api_key.is_empty() {
            return Ok(String::new());
        }

        let url = format!(
            "{}/v2/control/context",
            self.config.endpoint.trim_end_matches('/')
        );

        let mut payload = serde_json::json!({
            "run_id": self.session_id,
            "query": query,
            "format": "structured",
            "max_token_budget": self.config.max_token_budget,
        });

        // Optional filters are omitted from the payload when unset.
        if !self.config.entry_types.is_empty() {
            payload["entry_types"] = serde_json::json!(self.config.entry_types);
        }
        if !self.config.context_sections.is_empty() {
            payload["sections"] = serde_json::json!(self.config.context_sections);
        }

        let resp = self
            .http_client
            .post(&url)
            .bearer_auth(&self.config.api_key)
            .json(&payload)
            .send()
            .await?;

        // Deliberately lenient: a failing status yields an empty context
        // rather than an error.
        if !resp.status().is_success() {
            return Ok(String::new());
        }

        let body: serde_json::Value = resp.json().await?;
        Ok(body
            .get("context_block")
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .to_string())
    }
}