openserve 2.0.3

A modern, high-performance, AI-enhanced file server built in Rust
Documentation
//! File classification using AI

use anyhow::Result;
use async_openai::{
    types::{
        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
    },
    Client,
};
use serde::{Deserialize, Serialize};
use tracing::debug;

use std::sync::Arc;

use crate::config::AiConfig;

/// Represents the result of a file classification operation.
///
/// This is the shape of the JSON object that [`classify_content`] asks the
/// model to return; field names match the JSON keys.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Classification {
    /// The primary category of the file (e.g., "document", "code").
    pub category: String,
    /// An optional subcategory for more specific classification (e.g., "rust", "python").
    /// A missing or `null` value in the JSON deserializes to `None`.
    pub subcategory: Option<String>,
    /// The confidence score of the classification, as reported by the model.
    /// NOTE(review): the value is not validated or clamped here, so callers
    /// should not assume it lies in `[0, 1]`.
    pub confidence: f32,
    /// A brief description of the classification.
    pub description: String,
}

/// Classify file content
pub async fn classify_content(
    client: &Client<async_openai::config::OpenAIConfig>,
    content: &str,
    filename: &str,
    model: &str,
) -> Result<Classification> {
    let system_prompt = r#"
You are a file classification expert. Classify the given file content into one of these categories:

Categories:
- document: Text documents, PDFs, presentations
- code: Source code, scripts, configuration files
- data: CSV, JSON, XML, databases
- media: Images, videos, audio files
- archive: ZIP, TAR, compressed files
- other: Files that don't fit other categories

Respond with JSON format:
{
  "category": "category_name",
  "subcategory": "optional_subcategory",
  "confidence": 0.95,
  "description": "Brief description"
}
"#;

    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages([
            ChatCompletionRequestMessage::System(
                ChatCompletionRequestSystemMessageArgs::default()
                    .content(system_prompt)
                    .build()?,
            ),
            ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessageArgs::default()
                    .content(format!("Filename: {}\nContent preview:\n{}", 
                        filename, 
                        &content[..content.len().min(1000)]
                    ))
                    .build()?,
            ),
        ])
        .temperature(0.1)
        .max_tokens(200u32)
        .build()?;

    let response = client.chat().create(request).await?;
    
    let content = response.choices[0].message.content.as_ref()
        .ok_or_else(|| anyhow::anyhow!("No response from AI"))?;
    
    let classification: Classification = serde_json::from_str(content)?;
    Ok(classification)
}

/// Suggest file organization based on classification.
///
/// Counts files per `category`. For every category with at least two files,
/// it emits a "Create '<category>' folder for N files" suggestion. If no
/// category qualifies, it falls back to grouping by file extension and emits
/// "Group N files with '<ext>' extension" for every extension shared by at
/// least two paths. Paths without an extension are ignored.
///
/// Suggestions are returned in alphabetical order of category or extension,
/// so the output is the same on every run.
pub async fn suggest_organization(
    classifications: &[(&str, Classification)],
) -> Vec<String> {
    // BTreeMap rather than HashMap: HashMap iteration order is randomized,
    // which made the order of suggestions change from run to run.
    use std::collections::BTreeMap;

    // Minimum number of files that justifies proposing a dedicated group.
    const MIN_GROUP_SIZE: usize = 2;

    let mut per_category: BTreeMap<&str, usize> = BTreeMap::new();
    for (_, classification) in classifications {
        *per_category
            .entry(classification.category.as_str())
            .or_insert(0) += 1;
    }

    let mut suggestions: Vec<String> = per_category
        .into_iter()
        .filter(|&(_, count)| count >= MIN_GROUP_SIZE)
        .map(|(category, count)| format!("Create '{}' folder for {} files", category, count))
        .collect();

    // No category is large enough: fall back to grouping by file extension.
    if suggestions.is_empty() {
        let mut per_extension: BTreeMap<String, usize> = BTreeMap::new();
        for (path, _) in classifications {
            if let Some(ext) = std::path::Path::new(*path).extension() {
                *per_extension
                    .entry(ext.to_string_lossy().into_owned())
                    .or_insert(0) += 1;
            }
        }

        suggestions.extend(
            per_extension
                .into_iter()
                .filter(|&(_, count)| count >= MIN_GROUP_SIZE)
                .map(|(ext, count)| format!("Group {} files with '{}' extension", count, ext)),
        );
    }

    suggestions
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Classification` whose confidence and description are
    /// irrelevant to `suggest_organization`.
    fn classification(category: &str, subcategory: Option<&str>) -> Classification {
        Classification {
            category: category.to_string(),
            subcategory: subcategory.map(str::to_string),
            confidence: 0.9,
            description: format!("{} file", category),
        }
    }

    #[tokio::test]
    async fn test_suggest_organization() {
        let classifications = vec![
            ("file1.rs", classification("code", Some("rust"))),
            ("file2.py", classification("code", Some("python"))),
            ("doc.txt", classification("document", None)),
        ];

        let suggestions = suggest_organization(&classifications).await;
        // Only "code" reaches the two-file threshold; "document" has one file.
        assert_eq!(
            suggestions,
            vec!["Create 'code' folder for 2 files".to_string()]
        );
    }

    #[tokio::test]
    async fn test_suggest_organization_falls_back_to_extensions() {
        // Every category has a single file, so grouping falls back to extensions.
        let classifications = vec![
            ("a.rs", classification("code", None)),
            ("b.rs", classification("document", None)),
        ];

        let suggestions = suggest_organization(&classifications).await;
        assert_eq!(
            suggestions,
            vec!["Group 2 files with 'rs' extension".to_string()]
        );
    }

    #[tokio::test]
    async fn test_suggest_organization_empty_input() {
        let classifications: Vec<(&str, Classification)> = Vec::new();
        assert!(suggest_organization(&classifications).await.is_empty());
    }
}

/// A service for performing content classification tasks.
#[derive(Debug, Clone)]
pub struct ClassifierService {
    client: async_openai::Client<async_openai::config::OpenAIConfig>,
    config: Arc<AiConfig>,
}

impl ClassifierService {
    /// Creates a new `ClassifierService` instance.
    pub fn new(config: Arc<AiConfig>) -> Result<Self> {
        let client = async_openai::Client::new();
        Ok(Self { client, config })
    }

    /// Classify text into categories using AI
    pub async fn classify_text(&self, text: &str, categories: &[String]) -> Result<String> {
        debug!("Classifying text into {} categories", categories.len());

        let categories_str = categories.join(", ");
        
        let request = async_openai::types::CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                async_openai::types::ChatCompletionRequestMessage::System(
                    async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
                            format!("You are a text classifier. Classify the given text into one of these categories: {}. Respond with only the category name.", categories_str)
                        ))
                        .build()?
                ),
                async_openai::types::ChatCompletionRequestMessage::User(
                    async_openai::types::ChatCompletionRequestUserMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
                        .build()?
                )
            ])
            .max_tokens(50u32)
            .temperature(0.1)
            .build()?;

        let response = self.client.chat().create(request).await?;

        Ok(response.choices[0].message.content.as_ref()
            .unwrap_or(&"Unknown".to_string())
            .clone())
    }

    /// Detect the language of the given text
    pub async fn detect_language(&self, text: &str) -> Result<String> {
        debug!("Detecting language for text");

        let request = async_openai::types::CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                async_openai::types::ChatCompletionRequestMessage::System(
                    async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
                            "You are a language detection expert. Identify the language of the given text. Respond with only the language name in English.".to_string()
                        ))
                        .build()?
                ),
                async_openai::types::ChatCompletionRequestMessage::User(
                    async_openai::types::ChatCompletionRequestUserMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
                        .build()?
                )
            ])
            .max_tokens(20u32)
            .temperature(0.1)
            .build()?;

        let response = self.client.chat().create(request).await?;

        Ok(response.choices[0].message.content.as_ref()
            .unwrap_or(&"Unknown".to_string())
            .clone())
    }

    /// Perform sentiment analysis on the given text
    pub async fn sentiment_analysis(&self, text: &str) -> Result<String> {
        debug!("Analyzing sentiment for text");

        let request = async_openai::types::CreateChatCompletionRequestArgs::default()
            .model(&self.config.model)
            .messages([
                async_openai::types::ChatCompletionRequestMessage::System(
                    async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
                            "You are a sentiment analysis expert. Analyze the sentiment of the given text. Respond with only one word: 'positive', 'negative', or 'neutral'.".to_string()
                        ))
                        .build()?
                ),
                async_openai::types::ChatCompletionRequestMessage::User(
                    async_openai::types::ChatCompletionRequestUserMessageArgs::default()
                        .content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
                        .build()?
                )
            ])
            .max_tokens(10u32)
            .temperature(0.1)
            .build()?;

        let response = self.client.chat().create(request).await?;

        Ok(response.choices[0].message.content.as_ref()
            .unwrap_or(&"neutral".to_string())
            .clone())
    }
}