oris-runtime 0.15.0

An agentic workflow runtime and programmable AI execution system in Rust: stateful graphs, agents, tools, and multi-step execution.
use std::sync::Arc;

use async_trait::async_trait;

use crate::schemas::{Document, Retriever};

/// Configuration for FlashRank reranker
#[derive(Debug, Clone)]
pub struct FlashRankRerankerConfig {
    /// Model name to use (e.g., "ms-marco-MiniLM-L-12-v2")
    pub model: String,
    /// Top K documents to return after reranking
    pub top_k: Option<usize>,
}

impl Default for FlashRankRerankerConfig {
    fn default() -> Self {
        Self {
            model: "ms-marco-MiniLM-L-12-v2".to_string(),
            top_k: None,
        }
    }
}

/// FlashRank reranker that uses local ONNX models for reranking
///
/// Note: This is a placeholder implementation. Full implementation would require:
/// - ONNX Runtime integration (ort crate)
/// - Model loading and inference
/// - Tokenization support
///
/// For now, this provides a simple similarity-based reranking as a fallback.
pub struct FlashRankReranker {
    base_retriever: Arc<dyn Retriever>,
    config: FlashRankRerankerConfig,
}

impl FlashRankReranker {
    /// Create a new FlashRank reranker
    pub fn new(base_retriever: Arc<dyn Retriever>) -> Self {
        Self::with_config(base_retriever, FlashRankRerankerConfig::default())
    }

    /// Create a new FlashRank reranker with custom config
    pub fn with_config(
        base_retriever: Arc<dyn Retriever>,
        config: FlashRankRerankerConfig,
    ) -> Self {
        Self {
            base_retriever,
            config,
        }
    }

    /// Simple reranking based on query-document similarity
    /// In a full implementation, this would use the FlashRank ONNX model
    fn rerank_simple(&self, query: &str, documents: Vec<Document>) -> Vec<Document> {
        // Simple keyword-based scoring as placeholder
        let query_lower = query.to_lowercase();
        let query_words: Vec<&str> = query_lower.split_whitespace().collect();

        let mut scored: Vec<(Document, f64)> = documents
            .into_iter()
            .map(|doc| {
                let doc_lower = doc.page_content.to_lowercase();
                let score = query_words
                    .iter()
                    .map(|word| if doc_lower.contains(word) { 1.0 } else { 0.0 })
                    .sum::<f64>()
                    / query_words.len() as f64;
                (doc, score)
            })
            .collect();

        // Sort by score (descending)
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply top_k if configured
        let results: Vec<Document> = scored.into_iter().map(|(doc, _)| doc).collect();
        if let Some(k) = self.config.top_k {
            results.into_iter().take(k).collect()
        } else {
            results
        }
    }
}

#[async_trait]
impl Retriever for FlashRankReranker {
    async fn get_relevant_documents(
        &self,
        query: &str,
    ) -> Result<Vec<Document>, crate::error::RetrieverError> {
        // Get documents from base retriever
        let documents = self.base_retriever.get_relevant_documents(query).await?;

        // Rerank documents (using simple method for now)
        // TODO: Integrate actual FlashRank ONNX model when ort crate is available
        Ok(self.rerank_simple(query, documents))
    }
}