openserve 2.0.3

A modern, high-performance, AI-enhanced file server built in Rust
Documentation
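//! AI module: OpenAI-backed content analysis, embeddings for semantic search,
//! file-organization suggestions, chat, and named-entity extraction.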

use anyhow::Result;
use async_openai::{
    types::{
        ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestUserMessageArgs,
        CreateChatCompletionRequestArgs,
        CreateEmbeddingRequestArgs,
        ChatCompletionRequestMessage,
    },
    Client,
};
use sha2::Digest;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

pub mod embeddings;
pub mod search;
pub mod summarizer;
pub mod classifier;

use crate::config::AiConfig;
pub use embeddings::*;
pub use summarizer::*;
pub use classifier::*;

/// AI Service for intelligent file operations.
///
/// Bundles an OpenAI API client, the shared AI configuration, and an in-memory
/// cache of embedding vectors keyed by a SHA-256 hash of the embedded text.
/// Cloning shares the configuration and the cache, since both sit behind `Arc`.
#[derive(Clone)]
pub struct AiService {
    /// Client used for chat-completion and embedding requests.
    client: Client<async_openai::config::OpenAIConfig>,
    /// Shared AI settings (model name, API key, sampling parameters, ...).
    config: Arc<AiConfig>,
    /// Embedding cache shared between clones of the service.
    embeddings_cache: Arc<RwLock<moka::future::Cache<String, Vec<f32>>>>,
}

impl AiService {
    /// Create a new AI service instance
    pub async fn new(config: Arc<AiConfig>) -> Result<Self> {
        info!("Initializing AI service with model: {}", config.model);
        
        let client = if !config.api_key.is_empty() {
            Client::with_config(
                async_openai::config::OpenAIConfig::new()
                                          .with_api_key(&config.api_key)
            )
        } else {
            warn!("No OpenAI API key provided, AI features will be limited");
            return Err(anyhow::anyhow!("OpenAI API key required"));
        };

        let embeddings_cache = Arc::new(RwLock::new(
            moka::future::Cache::builder()
                .max_capacity(10_000)
                .time_to_live(std::time::Duration::from_secs(3600))
                .build()
        ));

        Ok(Self {
            client,
            config,
            embeddings_cache,
        })
    }

    /// Analyze file content and generate metadata
    pub async fn analyze_content(&self, content: &str, filename: &str) -> Result<ContentAnalysis> {
        debug!("Analyzing content for file: {}", filename);
        
        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content("You are a file content analyzer. Analyze the given content and provide: 
                        1. A brief summary (max 200 chars)
                        2. Key topics/tags (comma-separated)
                        3. Content type classification
                        4. Language detection
                        Return as JSON with fields: summary, tags, type, language")
                        .build()?
                ),
                ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessageArgs::default()
                        .content(format!("Filename: {}\nContent:\n{}", filename, 
                            &content[..content.len().min(2000)]))
                        .build()?
                ),
            ])
            .temperature(0.3)
            .max_tokens(300u32)
            .build()?;

        let response = self.client.chat().create(request).await?;
        
        let analysis_text = response.choices[0].message.content.as_ref()
            .ok_or_else(|| anyhow::anyhow!("No response from AI"))?;
        
        let analysis: ContentAnalysis = serde_json::from_str(analysis_text)
            .map_err(|e| anyhow::anyhow!("Failed to parse AI response: {}", e))?;
        
        Ok(analysis)
    }

    /// Generate embeddings for semantic search
    pub async fn generate_embeddings(&self, text: &str) -> Result<Vec<f32>> {
        // Check cache first
        let cache_key = format!("emb:{}", hex::encode(sha2::Sha256::digest(text.as_bytes())));
        
        if let Some(cached) = self.embeddings_cache.read().await.get(&cache_key).await {
            return Ok(cached);
        }

        debug!("Generating new embeddings");
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input([text])
            .build()?;

        let response = self.client.embeddings().create(request).await?;
        let embedding = response.data[0].embedding.clone();
        
        // Cache the result
        self.embeddings_cache.write().await
            .insert(cache_key, embedding.clone()).await;
        
        Ok(embedding)
    }

    /// Smart file organization suggestions
    pub async fn suggest_organization(&self, files: &[FileInfo]) -> Result<OrganizationSuggestion> {
        let file_list = files.iter()
            .take(100) // Limit to prevent token overflow
            .map(|f| format!("{} ({})", f.name, f.size_human))
            .collect::<Vec<_>>()
            .join("\n");

        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content("You are a file organization expert. Suggest folder structure and categorization for the given files. Return JSON with: folders (array of suggested folders with descriptions), rules (organization rules)")
                        .build()?
                ),
                ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessageArgs::default()
                        .content(format!("Files to organize:\n{}", file_list))
                        .build()?
                ),
            ])
            .temperature(0.5)
            .max_tokens(500u32)
            .build()?;

        let response = self.client.chat().create(request).await?;
        let suggestion_text = response.choices[0].message.content.as_ref()
            .ok_or_else(|| anyhow::anyhow!("No response from AI"))?;
        
        let suggestion: OrganizationSuggestion = serde_json::from_str(suggestion_text)
            .map_err(|e| anyhow::anyhow!("Failed to parse organization suggestion: {}", e))?;
        
        Ok(suggestion)
    }

    /// Chat with files using natural language
    pub async fn chat(&self, message: &str, context: &str) -> Result<String> {
        debug!("Processing chat message: {}", message);

        let system_prompt = if context.is_empty() {
            "You are a helpful AI assistant for file management. Answer questions about files and help with organization."
        } else {
            "You are a helpful AI assistant. Use the provided file context to answer questions accurately."
        };

        let user_content = if context.is_empty() {
            message.to_string()
        } else {
            format!("Context:\n{}\n\nQuestion: {}", context, message)
        };

        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content(system_prompt)
                        .build()?
                ),
                ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessageArgs::default()
                        .content(user_content)
                        .build()?
                ),
            ])
            .temperature(self.config.temperature)
            .max_tokens(self.config.max_tokens)
            .build()?;

        let response = self.client.chat().create(request).await?;
        
        let response_text = response.choices[0].message.content.as_ref()
            .ok_or_else(|| anyhow::anyhow!("No response from AI"))?
            .clone();
        
        Ok(response_text)
    }

    /// Extract entities from text
    pub async fn extract_entities(&self, content: &str) -> Result<Vec<Entity>> {
        debug!("Extracting entities from content");

        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content("Extract named entities from the text. Return JSON array with objects containing 'text', 'type' (PERSON, ORG, LOCATION, DATE, etc.), and 'confidence' fields.")
                        .build()?
                ),
                ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessageArgs::default()
                        .content(content)
                        .build()?
                ),
            ])
            .temperature(0.1)
            .max_tokens(400u32)
            .build()?;

        let response = self.client.chat().create(request).await?;
        
        let entities_text = response.choices[0].message.content.as_ref()
            .ok_or_else(|| anyhow::anyhow!("No response from AI"))?;
        
        let entities: Vec<Entity> = serde_json::from_str(entities_text)
            .map_err(|e| anyhow::anyhow!("Failed to parse entities: {}", e))?;
        
        Ok(entities)
    }
}

/// Represents the result of a content analysis operation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContentAnalysis {
    /// A brief summary of the content.
    pub summary: String,
    /// A list of relevant tags or keywords.
    pub tags: Vec<String>,
    /// The detected content type (e.g., "code", "document").
    pub content_type: String,
    /// The detected language of the content.
    pub language: String,
}

/// Represents a named entity extracted from text.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Entity {
    /// The text of the extracted entity.
    pub text: String,
    /// The type of the entity (e.g., "PERSON", "ORG").
    pub entity_type: String,
    /// The confidence score of the entity extraction.
    pub confidence: f32,
}

/// Represents basic information about a file for organization suggestions.
///
/// Only `name` and `size_human` are read when building the prompt in
/// `AiService::suggest_organization`. The other fields are carried along but
/// are not sent to the model there.
#[derive(Debug, Clone)]
pub struct FileInfo {
    /// The name of the file.
    pub name: String,
    /// The size of the file in bytes.
    pub size: u64,
    /// A human-readable representation of the file size.
    /// NOTE(review): presumably formatted from `size` by the caller — confirm.
    pub size_human: String,
    /// The last modified timestamp of the file (UTC).
    pub modified: chrono::DateTime<chrono::Utc>,
}

/// Represents a suggestion for organizing files.
///
/// Produced by deserializing the model's JSON reply in
/// `AiService::suggest_organization`. Both fields are required keys in that
/// JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OrganizationSuggestion {
    /// A list of suggested folders.
    pub folders: Vec<FolderSuggestion>,
    /// A list of suggested organization rules.
    pub rules: Vec<String>,
}

/// Represents a suggested folder in an organization suggestion.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FolderSuggestion {
    /// The name of the suggested folder.
    pub name: String,
    /// A description of the folder's purpose.
    pub description: String,
    /// A list of file patterns that should be moved to this folder.
    pub file_patterns: Vec<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A service built from a config with a non-empty key must construct
    /// successfully (no request is issued during construction).
    #[tokio::test]
    async fn test_ai_service_creation() {
        let cfg = AiConfig {
            enabled: true,
            api_key: String::from("test-key"),
            model: String::from("gpt-4o-mini"),
            max_tokens: 1000,
            temperature: 0.7,
            timeout: 30,
        };

        let outcome = AiService::new(Arc::new(cfg)).await;
        assert!(outcome.is_ok());
    }

    /// Serializing and then deserializing a `ContentAnalysis` keeps its
    /// summary and tags intact.
    #[test]
    fn test_content_analysis_serde() {
        let original = ContentAnalysis {
            summary: String::from("Test summary"),
            tags: ["test", "rust"].iter().map(|tag| tag.to_string()).collect(),
            content_type: String::from("code"),
            language: String::from("rust"),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<ContentAnalysis>(&encoded).unwrap();

        assert_eq!(original.summary, decoded.summary);
        assert_eq!(original.tags, decoded.tags);
    }
}