//! LLM Orchestration Integration using Rig
//!
//! This module provides high-level abstractions for building LLM-powered
//! reasoning pipelines using the Rig framework.
//!
//! # Features
//! - Multi-provider support (OpenAI, Anthropic, Ollama, etc.)
//! - RAG pipeline integration with reasonkit-mem
//! - ThinkTool-aware prompt construction
//! - Streaming response handling
//!
//! Enable with: `cargo build --features llm-orchestration`
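//!
//! # Example
//!
//! A minimal sketch of building and running a pipeline. The import path is
//! an assumption (it depends on where this module is mounted in the crate),
//! and `execute` currently returns a placeholder response.
//!
//! ```ignore
//! use reasonkit_core::orchestration::{OrchestrationConfig, PipelineBuilder};
//!
//! async fn run() -> anyhow::Result<()> {
//!     let pipeline = PipelineBuilder::new()
//!         .with_config(OrchestrationConfig {
//!             temperature: 0.2,
//!             ..Default::default()
//!         })
//!         .with_thinktool("GigaThink")
//!         .with_rag()
//!         .build()?;
//!
//!     let response = pipeline.execute("Summarize the evidence").await?;
//!     println!("{} (confidence {:.2})", response.content, response.confidence);
//!     Ok(())
//! }
//! ```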

use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

// Note: the `rig-core` crate provides the underlying LLM orchestration framework.
// Its library is named `rig`, so access it directly via `use rig::*;` when the
// llm-orchestration feature is enabled.

/// Configuration for LLM orchestration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationConfig {
    /// Default provider to use
    pub default_provider: LlmProvider,
    /// Maximum tokens for responses
    pub max_tokens: usize,
    /// Temperature for sampling
    pub temperature: f32,
    /// Enable streaming responses
    pub streaming: bool,
    /// Timeout in seconds
    pub timeout_secs: u64,
}

impl Default for OrchestrationConfig {
    fn default() -> Self {
        Self {
            default_provider: LlmProvider::OpenAI,
            max_tokens: 4096,
            temperature: 0.7,
            streaming: true,
            timeout_secs: 120,
        }
    }
}

/// Supported LLM providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    OpenAI,
    Anthropic,
    Ollama,
    DeepSeek,
    Groq,
    Custom,
}

/// Trait for ReasonKit-aware LLM agents
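///
/// # Example
///
/// A minimal sketch of a custom agent. `EchoAgent` is hypothetical (not part
/// of this crate) and simply echoes the prompt back without calling an LLM.
///
/// ```ignore
/// use async_trait::async_trait;
///
/// struct EchoAgent;
///
/// #[async_trait]
/// impl ReasoningAgent for EchoAgent {
///     async fn reason(
///         &self,
///         prompt: &str,
///         _context: Option<&ReasoningContext>,
///     ) -> Result<ReasoningResponse> {
///         Ok(ReasoningResponse {
///             content: prompt.to_string(),
///             model: "echo".to_string(),
///             provider: LlmProvider::Custom,
///             tokens_used: TokenUsage::default(),
///             reasoning_steps: Vec::new(),
///             confidence: 1.0,
///         })
///     }
///
///     async fn reason_stream(
///         &self,
///         _prompt: &str,
///         _context: Option<&ReasoningContext>,
///     ) -> Result<ReasoningStream> {
///         anyhow::bail!("EchoAgent does not support streaming")
///     }
///
///     fn provider(&self) -> LlmProvider {
///         LlmProvider::Custom
///     }
/// }
/// ```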
#[async_trait]
pub trait ReasoningAgent: Send + Sync {
    /// Execute a reasoning task with ThinkTool integration
    async fn reason(
        &self,
        prompt: &str,
        context: Option<&ReasoningContext>,
    ) -> Result<ReasoningResponse>;

    /// Execute with streaming response
    async fn reason_stream(
        &self,
        prompt: &str,
        context: Option<&ReasoningContext>,
    ) -> Result<ReasoningStream>;

    /// Get the provider for this agent
    fn provider(&self) -> LlmProvider;
}

/// Context for reasoning operations
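///
/// # Example
///
/// A minimal sketch of assembling a context; the profile name is illustrative:
///
/// ```ignore
/// let ctx = ReasoningContext {
///     documents: vec!["retrieved passage".to_string()],
///     thinktool_profile: Some("GigaThink".to_string()),
///     ..Default::default()
/// };
/// ```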
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReasoningContext {
    /// Retrieved documents from RAG
    pub documents: Vec<String>,
    /// Active ThinkTool profile
    pub thinktool_profile: Option<String>,
    /// Previous reasoning steps
    pub history: Vec<ReasoningStep>,
    /// Custom metadata
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}

/// A single reasoning step
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningStep {
    pub tool: String,
    pub input: String,
    pub output: String,
    pub confidence: f32,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}

/// Response from a reasoning operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningResponse {
    pub content: String,
    pub model: String,
    pub provider: LlmProvider,
    pub tokens_used: TokenUsage,
    pub reasoning_steps: Vec<ReasoningStep>,
    pub confidence: f32,
}

/// Token usage statistics
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

/// Streaming response wrapper
pub struct ReasoningStream {
    // Implementation would wrap Rig's streaming types
    _inner: (),
}

/// Builder for creating reasoning pipelines
pub struct PipelineBuilder {
    config: OrchestrationConfig,
    tools: Vec<String>,
    rag_enabled: bool,
}

impl PipelineBuilder {
    /// Create a builder with the default configuration.
    pub fn new() -> Self {
        Self {
            config: OrchestrationConfig::default(),
            tools: Vec::new(),
            rag_enabled: false,
        }
    }

    /// Replace the default configuration.
    pub fn with_config(mut self, config: OrchestrationConfig) -> Self {
        self.config = config;
        self
    }

    /// Add a ThinkTool to the pipeline by name.
    pub fn with_thinktool(mut self, tool: &str) -> Self {
        self.tools.push(tool.to_string());
        self
    }

    /// Enable RAG retrieval via reasonkit-mem.
    pub fn with_rag(mut self) -> Self {
        self.rag_enabled = true;
        self
    }

    /// Finalize the builder into a [`ReasoningPipeline`].
    pub fn build(self) -> Result<ReasoningPipeline> {
        Ok(ReasoningPipeline {
            config: self.config,
            tools: self.tools,
            rag_enabled: self.rag_enabled,
        })
    }
}

impl Default for PipelineBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// A configured reasoning pipeline
pub struct ReasoningPipeline {
    config: OrchestrationConfig,
    tools: Vec<String>,
    rag_enabled: bool,
}

impl ReasoningPipeline {
    /// Execute the pipeline with the given input
    pub async fn execute(&self, input: &str) -> Result<ReasoningResponse> {
        // Pipeline execution is not yet implemented; this logs the
        // configuration and returns a placeholder response. Real execution
        // will integrate with Rig's pipeline system.
        tracing::info!(
            provider = ?self.config.default_provider,
            tools = ?self.tools,
            rag = self.rag_enabled,
            "Executing reasoning pipeline"
        );

        Ok(ReasoningResponse {
            content: format!("Pipeline executed for: {}", input),
            model: "gpt-4".to_string(),
            provider: self.config.default_provider,
            tokens_used: TokenUsage::default(),
            reasoning_steps: Vec::new(),
            confidence: 0.85,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = OrchestrationConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert_eq!(config.default_provider, LlmProvider::OpenAI);
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new()
            .with_thinktool("GigaThink")
            .with_thinktool("LaserLogic")
            .with_rag()
            .build()
            .unwrap();

        assert!(pipeline.rag_enabled);
        assert_eq!(pipeline.tools.len(), 2);
    }
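
    #[test]
    fn test_provider_serde_lowercase() {
        // `rename_all = "lowercase"` serializes variants as lowercase strings.
        let json = serde_json::to_string(&LlmProvider::OpenAI).unwrap();
        assert_eq!(json, "\"openai\"");
        let provider: LlmProvider = serde_json::from_str("\"deepseek\"").unwrap();
        assert_eq!(provider, LlmProvider::DeepSeek);
    }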
}