use anyhow::Result;
use async_openai::{
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
},
Client,
};
use serde::{Deserialize, Serialize};
use tracing::debug;
use std::sync::Arc;
use crate::config::AiConfig;
/// Result of an AI classification of a single file's content.
///
/// Deserialized directly from the model's JSON reply (see the system
/// prompt in `classify_content` for the expected schema).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Classification {
// Top-level category: "document", "code", "data", "media", "archive", or "other".
pub category: String,
// Optional finer-grained label (e.g. a language for code files).
pub subcategory: Option<String>,
// Model self-reported confidence; prompt example uses 0.95 — presumably
// in [0.0, 1.0] but not validated anywhere in this file.
pub confidence: f32,
// Brief human-readable description produced by the model.
pub description: String,
}
/// Classifies `content` into one of the predefined file categories by
/// asking the chat model for a JSON verdict.
///
/// Only the first ~1000 bytes of `content` are sent as a preview,
/// truncated at a UTF-8 character boundary.
///
/// # Errors
/// Fails when the request cannot be built, the API call errors, the
/// model returns no content, or the reply is not valid JSON matching
/// [`Classification`].
pub async fn classify_content(
    client: &Client<async_openai::config::OpenAIConfig>,
    content: &str,
    filename: &str,
    model: &str,
) -> Result<Classification> {
    let system_prompt = r#"
You are a file classification expert. Classify the given file content into one of these categories:
Categories:
- document: Text documents, PDFs, presentations
- code: Source code, scripts, configuration files
- data: CSV, JSON, XML, databases
- media: Images, videos, audio files
- archive: ZIP, TAR, compressed files
- other: Files that don't fit other categories
Respond with JSON format:
{
"category": "category_name",
"subcategory": "optional_subcategory",
"confidence": 0.95,
"description": "Brief description"
}
"#;

    // Bug fix: `&content[..content.len().min(1000)]` panics when byte 1000
    // falls inside a multi-byte UTF-8 character. Walk back to a boundary.
    let mut preview_end = content.len().min(1000);
    while !content.is_char_boundary(preview_end) {
        preview_end -= 1;
    }

    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages([
            ChatCompletionRequestMessage::System(
                ChatCompletionRequestSystemMessageArgs::default()
                    .content(system_prompt)
                    .build()?,
            ),
            ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessageArgs::default()
                    .content(format!(
                        "Filename: {}\nContent preview:\n{}",
                        filename,
                        &content[..preview_end]
                    ))
                    .build()?,
            ),
        ])
        .temperature(0.1)
        .max_tokens(200u32)
        .build()?;

    let response = client.chat().create(request).await?;
    // Bug fix: `response.choices[0]` panics when the API returns no
    // choices; use `first()` so both failure modes become one error.
    let reply = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref())
        .ok_or_else(|| anyhow::anyhow!("No response from AI"))?;
    let classification: Classification = serde_json::from_str(reply)?;
    Ok(classification)
}
/// Suggests organization actions for a set of classified files.
///
/// Primary strategy: propose one folder per classification category that
/// covers at least two files. Fallback (only when no category repeats):
/// propose grouping files that share a filename extension.
///
/// Fix: buckets are sorted by name before emitting suggestions so the
/// output is deterministic — bare `HashMap` iteration order is randomized
/// per process, which made results (and tests) flaky.
pub async fn suggest_organization(
    classifications: &[(&str, Classification)],
) -> Vec<String> {
    let mut suggestions = Vec::new();

    // Count files per AI-assigned category (entry API: one lookup each).
    let mut by_category: std::collections::HashMap<&str, usize> =
        std::collections::HashMap::new();
    for (_, classification) in classifications {
        *by_category
            .entry(classification.category.as_str())
            .or_default() += 1;
    }

    let mut categories: Vec<_> = by_category.into_iter().collect();
    categories.sort_unstable();
    for (category, count) in categories {
        if count >= 2 {
            suggestions.push(format!("Create '{}' folder for {} files", category, count));
        }
    }

    // Fallback: group by extension when no category has 2+ members.
    if suggestions.is_empty() {
        let mut by_extension: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        for (path, _) in classifications {
            if let Some(ext) = std::path::Path::new(path).extension() {
                *by_extension
                    .entry(ext.to_string_lossy().to_string())
                    .or_default() += 1;
            }
        }
        let mut extensions: Vec<_> = by_extension.into_iter().collect();
        extensions.sort_unstable();
        for (ext, count) in extensions {
            if count >= 2 {
                suggestions.push(format!("Group {} files with '{}' extension", count, ext));
            }
        }
    }

    suggestions
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Classification` fixture without repeating field boilerplate.
    fn fixture(category: &str, subcategory: Option<&str>, confidence: f32, description: &str) -> Classification {
        Classification {
            category: category.to_string(),
            subcategory: subcategory.map(str::to_string),
            confidence,
            description: description.to_string(),
        }
    }

    #[tokio::test]
    async fn test_suggest_organization() {
        let classifications = vec![
            ("file1.rs", fixture("code", Some("rust"), 0.95, "Rust source code")),
            ("file2.py", fixture("code", Some("python"), 0.90, "Python script")),
            ("doc.txt", fixture("document", None, 0.85, "Text document")),
        ];

        // Two "code" files should produce at least one grouping suggestion.
        let suggestions = suggest_organization(&classifications).await;
        assert!(!suggestions.is_empty());
    }
}
/// High-level AI text-analysis service (classification, language
/// detection, sentiment) backed by an OpenAI-compatible chat API.
#[derive(Debug, Clone)]
pub struct ClassifierService {
// Chat client; built with env-based defaults in `new` (config is not
// consulted for credentials — see note on `new`).
client: async_openai::Client<async_openai::config::OpenAIConfig>,
// Shared settings; only `model` is read by the methods in this file.
config: Arc<AiConfig>,
}
impl ClassifierService {
pub fn new(config: Arc<AiConfig>) -> Result<Self> {
let client = async_openai::Client::new();
Ok(Self { client, config })
}
pub async fn classify_text(&self, text: &str, categories: &[String]) -> Result<String> {
debug!("Classifying text into {} categories", categories.len());
let categories_str = categories.join(", ");
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model(&self.config.model)
.messages([
async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
format!("You are a text classifier. Classify the given text into one of these categories: {}. Respond with only the category name.", categories_str)
))
.build()?
),
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
.build()?
)
])
.max_tokens(50u32)
.temperature(0.1)
.build()?;
let response = self.client.chat().create(request).await?;
Ok(response.choices[0].message.content.as_ref()
.unwrap_or(&"Unknown".to_string())
.clone())
}
pub async fn detect_language(&self, text: &str) -> Result<String> {
debug!("Detecting language for text");
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model(&self.config.model)
.messages([
async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
"You are a language detection expert. Identify the language of the given text. Respond with only the language name in English.".to_string()
))
.build()?
),
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
.build()?
)
])
.max_tokens(20u32)
.temperature(0.1)
.build()?;
let response = self.client.chat().create(request).await?;
Ok(response.choices[0].message.content.as_ref()
.unwrap_or(&"Unknown".to_string())
.clone())
}
pub async fn sentiment_analysis(&self, text: &str) -> Result<String> {
debug!("Analyzing sentiment for text");
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model(&self.config.model)
.messages([
async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
"You are a sentiment analysis expert. Analyze the sentiment of the given text. Respond with only one word: 'positive', 'negative', or 'neutral'.".to_string()
))
.build()?
),
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessageArgs::default()
.content(async_openai::types::ChatCompletionRequestUserMessageContent::Text(text.to_string()))
.build()?
)
])
.max_tokens(10u32)
.temperature(0.1)
.build()?;
let response = self.client.chat().create(request).await?;
Ok(response.choices[0].message.content.as_ref()
.unwrap_or(&"neutral".to_string())
.clone())
}
}