use super::CommentaryStyle;
use crate::video::metadata::VideoFormat;
use crate::video::VideoMetadata;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Everything a [`CommentaryGenerator`](super::traits::CommentaryGenerator) needs to
/// produce one commentary for a video.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommentaryInput {
    /// Metadata of the video the commentary is about.
    pub video_metadata: VideoMetadata,
    /// Tone/style requested for the commentary.
    pub style: CommentaryStyle,
    /// Language tag for the commentary (the `Default` impl uses `"en"`).
    pub language: String,
    /// Optional upper length bound.
    /// NOTE(review): the unit (characters vs. words) is not fixed by this type — confirm.
    pub max_length: Option<usize>,
    /// Optional lower length bound; same unit caveat as `max_length`.
    pub min_length: Option<usize>,
    /// Free-form extra instructions appended to the generation prompt.
    pub custom_instructions: Option<String>,
    /// Whether to ask for `[m:ss]`-style timestamps and parse them into the result.
    pub include_timestamps: bool,
    /// Whether to ask for keywords and parse them into the result.
    pub include_keywords: bool,
}
impl Default for CommentaryInput {
    /// Returns a placeholder input: a zero-length 1920x1080 MP4 titled "Test Video",
    /// professional style, English, 500–1000 length bounds, keywords on, timestamps off.
    fn default() -> Self {
        let placeholder_video = VideoMetadata::new(
            String::from("Test Video"),
            chrono::Duration::seconds(0),
            (1920, 1080),
            VideoFormat::MP4,
        );

        Self {
            video_metadata: placeholder_video,
            style: CommentaryStyle::Professional,
            language: String::from("en"),
            max_length: Some(1000),
            min_length: Some(500),
            custom_instructions: None,
            include_timestamps: false,
            include_keywords: true,
        }
    }
}
/// A generated commentary together with the data extracted from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Commentary {
    /// Unique identifier (a UUID v4 string when created through `Commentary::new`).
    pub id: String,
    /// Identifier of the video this commentary describes.
    pub video_id: String,
    /// The commentary text as returned by the model.
    pub content: String,
    /// Style the commentary was generated in.
    pub style: CommentaryStyle,
    /// Language tag the commentary was requested in.
    pub language: String,
    /// Quality score; `set_quality_score` clamps it into `[0.0, 1.0]`.
    pub quality_score: f64,
    /// Keywords extracted from `content`.
    pub keywords: Vec<String>,
    /// `(offset in seconds, description)` pairs extracted from `content`.
    pub timestamps: Vec<(f64, String)>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC).
    pub updated_at: DateTime<Utc>,
}
impl Commentary {
    /// Creates a commentary with a fresh UUID v4 `id`.
    ///
    /// `quality_score` starts at `0.0`, `keywords` and `timestamps` start empty, and
    /// `created_at` / `updated_at` are both set to the same current UTC instant.
    pub fn new(
        video_id: String,
        content: String,
        style: CommentaryStyle,
        language: String,
    ) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            video_id,
            content,
            style,
            language,
            quality_score: 0.0,
            keywords: Vec::new(),
            timestamps: Vec::new(),
            created_at: now,
            updated_at: now,
        }
    }

    /// Sets `updated_at` to the current UTC time.
    pub fn update_timestamp(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Stores `score` clamped into `[0.0, 1.0]`.
    ///
    /// NaN is stored as `0.0`. `f64::clamp` passes NaN through unchanged, and a NaN score
    /// would be serialized by `serde_json` as `null`. That `null` then fails to deserialize
    /// back into `f64`, so a NaN would make the commentary unreadable after a JSON round trip.
    pub fn set_quality_score(&mut self, score: f64) {
        self.quality_score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };
    }
}
/// Result of a single generation request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommentaryOutput {
    /// The generated commentary.
    pub commentary: Commentary,
    /// Wall-clock time spent producing the commentary, in seconds.
    pub generation_time: f64,
    /// Total tokens (prompt + completion) reported by the API.
    pub tokens_used: u32,
    /// `true` when the model stopped because it hit its token limit.
    pub truncated: bool,
    /// Suggestions for improving the commentary (the OpenAI generator leaves this empty).
    pub suggested_improvements: Vec<String>,
}
/// Connection and sampling settings for an OpenAI-compatible chat-completions endpoint.
///
/// `Debug` is implemented by hand instead of derived so that formatting a config with
/// `{:?}` (e.g. in logs or error reports) never prints the API key.
#[derive(Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// Secret key sent as a bearer token in the `Authorization` header.
    pub api_key: String,
    /// Full URL of the chat-completions endpoint.
    pub api_endpoint: String,
    /// Model name put in every request.
    pub default_model: String,
    /// Sampling temperature for generation and improvement requests.
    pub default_temperature: f32,
    /// Completion token cap for generation and improvement requests.
    pub default_max_tokens: u32,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAIConfig")
            .field("api_key", &api_key)
            .field("api_endpoint", &self.api_endpoint)
            .field("default_model", &self.default_model)
            .field("default_temperature", &self.default_temperature)
            .field("default_max_tokens", &self.default_max_tokens)
            .finish()
    }
}
impl Default for OpenAIConfig {
fn default() -> Self {
Self {
api_key: String::new(),
api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
default_model: "gpt-3.5-turbo".to_string(),
default_temperature: 0.7,
default_max_tokens: 1000,
}
}
}
/// One message in a chat-completions conversation.
///
/// Field names already match the wire format, so no `serde(rename)` is needed.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    /// Speaker role, e.g. `"system"` or `"user"`.
    role: String,
    /// Message text.
    content: String,
}
/// JSON body of a chat-completions request.
///
/// Field order is kept as-is because it determines the serialized key order.
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
    /// Model to run.
    model: String,
    /// Conversation, in order.
    messages: Vec<ChatMessage>,
    temperature: f32,
    /// Cap on completion tokens.
    max_tokens: u32,
    top_p: f32,
    frequency_penalty: f32,
    presence_penalty: f32,
}
/// Subset of a chat-completions response body that the generator deserializes.
///
/// This module reads only `usage` and `choices`. The metadata fields fall back to their
/// defaults when absent, so that OpenAI-compatible endpoints omitting them still parse.
/// The endpoint URL is configurable.
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
    #[serde(default)]
    id: String,
    #[serde(default)]
    object: String,
    #[serde(default)]
    created: u64,
    #[serde(default)]
    model: String,
    /// Token accounting for the request.
    usage: Usage,
    /// Candidate completions; the generator uses the first one.
    choices: Vec<Choice>,
}
/// Token accounting returned with a chat completion.
#[derive(Debug, Serialize, Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    /// Reported to callers as `CommentaryOutput::tokens_used`.
    total_tokens: u32,
}
/// One candidate completion in a chat-completions response.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    /// The assistant message produced by the model.
    message: ChatMessage,
    /// Why generation stopped; the generator treats `"length"` as truncation.
    finish_reason: String,
    /// Position of this choice in the response.
    index: u32,
}
/// Commentary generator backed by an OpenAI-compatible chat-completions endpoint.
///
/// `Clone` is what lets `generate_multiple` hand a copy to each spawned task.
#[derive(Clone)]
pub struct OpenAIGenerator {
    /// Endpoint, credentials and sampling defaults.
    config: OpenAIConfig,
    /// HTTP client used for every request.
    client: reqwest::Client,
}
impl OpenAIGenerator {
    /// Creates a generator using [`OpenAIConfig::default`] settings and the given API key.
    pub fn new(api_key: String) -> Self {
        Self::new_with_config(OpenAIConfig {
            api_key,
            ..Default::default()
        })
    }

    /// Creates a generator from a fully specified configuration.
    pub fn new_with_config(config: OpenAIConfig) -> Self {
        let client = reqwest::Client::new();
        Self { config, client }
    }

    /// Builds the user prompt sent to the model for `input`.
    ///
    /// The prompt contains, in order:
    /// - a style description and the video metadata;
    /// - optional custom instructions;
    /// - the requested language and optional length bounds;
    /// - when requested, formatting rules for keywords and timestamps. These rules match
    ///   what [`Self::extract_keywords`] and [`Self::extract_timestamps`] parse.
    fn build_prompt(&self, input: &CommentaryInput) -> String {
        let style_desc = match input.style {
            CommentaryStyle::Professional => "professional, formal, technical commentary",
            CommentaryStyle::Casual => "relaxed, conversational, friendly style",
            CommentaryStyle::Educational => "informative, teaching-focused, clear explanations",
            CommentaryStyle::Entertaining => "engaging, humorous, entertaining style",
            CommentaryStyle::Analytical => "data-driven, detailed analysis, objective",
            CommentaryStyle::Storytelling => "narrative, storytelling approach, engaging",
            CommentaryStyle::Poetic => "creative, poetic language, artistic",
            CommentaryStyle::Technical => "highly technical, detailed explanations, precise",
        };

        let meta = &input.video_metadata;
        let mut prompt = format!(
            "Generate {} for a video with the following metadata:\n\n\
             Title: {}\n\
             Duration: {} seconds\n\
             Resolution: {}x{}\n\
             Format: {:?}\n\
             Frame rate: {:.2} fps\n\
             Video codec: {}\n\
             Audio codec: {}\n",
            style_desc,
            meta.title,
            meta.duration.num_seconds(),
            meta.resolution.0,
            meta.resolution.1,
            meta.format,
            meta.frame_rate,
            meta.video_codec,
            meta.audio_codec,
        );

        if let Some(instructions) = &input.custom_instructions {
            prompt.push_str(&format!("\nAdditional instructions: {}\n", instructions));
        }
        prompt.push_str(
            "\nGenerate a comprehensive commentary that covers the key aspects of this video.",
        );

        // Pass the caller's language and length requirements to the model; without these
        // the corresponding CommentaryInput fields would have no effect.
        if !input.language.trim().is_empty() {
            prompt.push_str(&format!(
                " Write the commentary in the following language: {}.",
                input.language.trim()
            ));
        }
        // NOTE(review): the unit of min_length/max_length is not defined in this file;
        // characters are assumed here — confirm with callers.
        match (input.min_length, input.max_length) {
            (Some(min), Some(max)) => prompt.push_str(&format!(
                " The commentary should be between {} and {} characters long.",
                min, max
            )),
            (Some(min), None) => prompt.push_str(&format!(
                " The commentary should be at least {} characters long.",
                min
            )),
            (None, Some(max)) => prompt.push_str(&format!(
                " The commentary should be at most {} characters long.",
                max
            )),
            (None, None) => {}
        }

        if input.include_keywords {
            // Ask for the exact "Keywords: " marker that `extract_keywords` searches for.
            prompt.push_str(
                " Also, include 5-7 relevant keywords at the end, on a final line that starts \
                 with \"Keywords: \" and lists them separated by commas.",
            );
        }
        if input.include_timestamps {
            prompt.push_str(
                " Include timestamps for key points, one per line, in the format [0:00] Description.",
            );
        }
        prompt
    }

    /// Extracts the comma-separated keyword list following a `Keywords:` marker in `content`.
    ///
    /// How the marker and list are read:
    /// - The last occurrence of the marker is used, matched ASCII case-insensitively, so
    ///   `**Keywords:**` also works.
    /// - Only the first non-blank line after the marker is read.
    /// - Each entry is trimmed of whitespace, `*` and `.`, and empty entries are dropped.
    ///
    /// Returns an empty vector when there is no marker.
    fn extract_keywords(&self, content: &str) -> Vec<String> {
        const MARKER: &str = "keywords:";
        // ASCII lowercasing never changes byte lengths, and MARKER is pure ASCII, so an
        // offset found in `lowered` is also a valid char boundary in `content`.
        let lowered = content.to_ascii_lowercase();
        match lowered.rfind(MARKER) {
            Some(pos) => content[pos + MARKER.len()..]
                .lines()
                .find(|line| !line.trim().is_empty())
                .unwrap_or("")
                .split(',')
                .map(|keyword| {
                    keyword.trim_matches(|c: char| c.is_whitespace() || c == '*' || c == '.')
                })
                .filter(|keyword| !keyword.is_empty())
                .map(str::to_string)
                .collect(),
            None => Vec::new(),
        }
    }

    /// Extracts `[m:ss] Description` / `[h:mm:ss] Description` entries, at most one per line.
    ///
    /// The first bracketed group on a line that parses as a timestamp is used. The
    /// description is the trimmed text after its closing `]`. Lines without a valid
    /// timestamp are skipped.
    fn extract_timestamps(&self, content: &str) -> Vec<(f64, String)> {
        content
            .lines()
            .filter_map(|line| {
                line.match_indices('[').find_map(|(open, _)| {
                    let (time_str, desc) = line[open + 1..].split_once(']')?;
                    let seconds = Self::parse_timestamp(time_str)?;
                    Some((seconds, desc.trim().to_string()))
                })
            })
            .collect()
    }

    /// Converts `m:ss` or `h:mm:ss` into seconds; the seconds field may be fractional.
    ///
    /// Returns `None` for any other shape, for non-integer hours/minutes, and for
    /// negative or non-finite seconds.
    fn parse_timestamp(time_str: &str) -> Option<f64> {
        let parts: Vec<&str> = time_str.trim().split(':').collect();
        let (seconds_part, larger_units) = parts.split_last()?;
        if !(1..=2).contains(&larger_units.len()) {
            return None;
        }
        let seconds: f64 = seconds_part.trim().parse().ok()?;
        if !seconds.is_finite() || seconds < 0.0 {
            return None;
        }
        let mut total = 0.0;
        for unit in larger_units {
            let value: u32 = unit.trim().parse().ok()?;
            total = total * 60.0 + f64::from(value);
        }
        Some(total * 60.0 + seconds)
    }
}
#[async_trait::async_trait]
impl super::traits::CommentaryGenerator for OpenAIGenerator {
async fn generate_commentary(&self, input: CommentaryInput) -> crate::Result<CommentaryOutput> {
let start_time = std::time::Instant::now();
let prompt = self.build_prompt(&input);
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: "You are an AI assistant specialized in generating high-quality video commentaries.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: prompt,
},
];
let request = ChatCompletionRequest {
model: self.config.default_model.clone(),
messages,
temperature: self.config.default_temperature,
max_tokens: self.config.default_max_tokens,
top_p: 1.0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
};
let response = self
.client
.post(&self.config.api_endpoint)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| crate::error::CoreError::NetworkError(e.to_string()))?;
let completion_response = response
.json::<ChatCompletionResponse>()
.await
.map_err(|e| crate::error::CoreError::JsonError(e.to_string()))?;
let choice = completion_response
.choices
.into_iter()
.next()
.ok_or_else(|| {
crate::error::CoreError::CommentaryError(
"No choices returned from OpenAI API".to_string(),
)
})?;
let content = choice.message.content;
let truncated = choice.finish_reason == "length";
let tokens_used = completion_response.usage.total_tokens;
let mut commentary = Commentary::new(
input.video_metadata.id.clone(),
content.clone(),
input.style,
input.language.clone(),
);
if input.include_keywords {
commentary.keywords = self.extract_keywords(&content);
}
if input.include_timestamps {
commentary.timestamps = self.extract_timestamps(&content);
}
let generation_time = start_time.elapsed().as_secs_f64();
let output = CommentaryOutput {
commentary,
generation_time,
tokens_used,
truncated,
suggested_improvements: Vec::new(),
};
Ok(output)
}
async fn evaluate_commentary(&self, commentary: &Commentary) -> crate::Result<f64> {
let prompt = format!(
"Evaluate the quality of the following video commentary on a scale of 0.0 to 1.0, \
where 1.0 is perfect. Consider relevance, clarity, engagement, and style appropriateness. \
Only return a single floating point number without any explanation.\n\n{}",
commentary.content
);
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: "You are an AI assistant specialized in evaluating video commentaries."
.to_string(),
},
ChatMessage {
role: "user".to_string(),
content: prompt,
},
];
let request = ChatCompletionRequest {
model: self.config.default_model.clone(),
messages,
temperature: 0.0, max_tokens: 10,
top_p: 1.0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
};
let response = self
.client
.post(&self.config.api_endpoint)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| crate::error::CoreError::NetworkError(e.to_string()))?;
let completion_response = response
.json::<ChatCompletionResponse>()
.await
.map_err(|e| crate::error::CoreError::JsonError(e.to_string()))?;
let choice = completion_response
.choices
.into_iter()
.next()
.ok_or_else(|| {
crate::error::CoreError::CommentaryError(
"No choices returned from OpenAI API".to_string(),
)
})?;
let score_str = choice.message.content.trim();
let score = score_str.parse::<f64>().map_err(|e| {
crate::error::CoreError::CommentaryError(format!("Failed to parse score: {}", e))
})?;
Ok(score.clamp(0.0, 1.0))
}
async fn improve_commentary(
&self,
commentary: &Commentary,
feedback: &str,
) -> crate::Result<Commentary> {
let prompt = format!(
"Improve the following video commentary based on the feedback provided. \
Maintain the same style and language.\n\nCommentary:\n{}\n\nFeedback:\n{}",
commentary.content, feedback
);
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: "You are an AI assistant specialized in improving video commentaries based on feedback.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: prompt,
},
];
let request = ChatCompletionRequest {
model: self.config.default_model.clone(),
messages,
temperature: self.config.default_temperature,
max_tokens: self.config.default_max_tokens,
top_p: 1.0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
};
let response = self
.client
.post(&self.config.api_endpoint)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| crate::error::CoreError::NetworkError(e.to_string()))?;
let completion_response = response
.json::<ChatCompletionResponse>()
.await
.map_err(|e| crate::error::CoreError::JsonError(e.to_string()))?;
let choice = completion_response
.choices
.into_iter()
.next()
.ok_or_else(|| {
crate::error::CoreError::CommentaryError(
"No choices returned from OpenAI API".to_string(),
)
})?;
let improved_content = choice.message.content;
let mut improved_commentary = Commentary::new(
commentary.video_id.clone(),
improved_content.clone(),
commentary.style,
commentary.language.clone(),
);
improved_commentary.keywords = self.extract_keywords(&improved_content);
improved_commentary.timestamps = self.extract_timestamps(&improved_content);
Ok(improved_commentary)
}
async fn generate_multiple(
&self,
input: CommentaryInput,
styles: Vec<CommentaryStyle>,
) -> crate::Result<Vec<CommentaryOutput>> {
let mut handles = Vec::new();
for style in styles {
let mut input_with_style = input.clone();
input_with_style.style = style;
let generator = self.clone();
let handle =
tokio::spawn(async move { generator.generate_commentary(input_with_style).await });
handles.push(handle);
}
let mut results = Vec::new();
for handle in handles {
let result = handle.await.map_err(|e| {
crate::error::CoreError::InternalError(format!("Task failed: {}", e))
})??;
results.push(result);
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Commentary fixture shared by the tests that do not inspect the constructor inputs.
    fn sample_commentary() -> Commentary {
        Commentary::new(
            "test-video-id".to_string(),
            "This is a test commentary".to_string(),
            CommentaryStyle::Professional,
            "en".to_string(),
        )
    }

    #[test]
    fn test_commentary_input_default() {
        let input = CommentaryInput::default();
        assert_eq!(input.video_metadata.title, "Test Video");
        assert_eq!(input.style, CommentaryStyle::Professional);
        assert_eq!(input.language, "en");
        assert_eq!(input.max_length, Some(1000));
        assert_eq!(input.min_length, Some(500));
        assert!(!input.include_timestamps);
        assert!(input.include_keywords);
    }

    #[test]
    fn test_openai_config_default() {
        let config = OpenAIConfig::default();
        assert!(config.api_key.is_empty());
        assert_eq!(
            config.api_endpoint,
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(config.default_model, "gpt-3.5-turbo");
        assert_eq!(config.default_temperature, 0.7);
        assert_eq!(config.default_max_tokens, 1000);
    }

    #[test]
    fn test_commentary_creation() {
        let video_id = "test-video-id".to_string();
        let content = "This is a test commentary".to_string();
        let style = CommentaryStyle::Professional;
        let language = "en".to_string();
        let commentary =
            Commentary::new(video_id.clone(), content.clone(), style, language.clone());
        assert_eq!(commentary.video_id, video_id);
        assert_eq!(commentary.content, content);
        assert_eq!(commentary.style, style);
        assert_eq!(commentary.language, language);
        assert_eq!(commentary.quality_score, 0.0);
        assert!(commentary.keywords.is_empty());
        assert!(commentary.timestamps.is_empty());
        assert_eq!(commentary.created_at, commentary.updated_at);
    }

    #[test]
    fn test_commentary_update_timestamp() {
        let mut commentary = sample_commentary();
        let old_updated_at = commentary.updated_at;
        std::thread::sleep(std::time::Duration::from_millis(10));
        commentary.update_timestamp();
        assert!(commentary.updated_at > old_updated_at);
    }

    #[test]
    fn test_commentary_set_quality_score() {
        let mut commentary = sample_commentary();
        commentary.set_quality_score(0.75);
        assert_eq!(commentary.quality_score, 0.75);
        commentary.set_quality_score(-0.5);
        assert_eq!(commentary.quality_score, 0.0);
        commentary.set_quality_score(1.5);
        assert_eq!(commentary.quality_score, 1.0);
    }

    #[test]
    fn test_commentary_serialization() {
        let commentary = sample_commentary();
        let json = serde_json::to_string(&commentary).unwrap();
        let deserialized: Commentary = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, commentary.id);
        assert_eq!(deserialized.video_id, commentary.video_id);
        assert_eq!(deserialized.content, commentary.content);
        assert_eq!(deserialized.style, commentary.style);
        assert_eq!(deserialized.language, commentary.language);
        assert_eq!(deserialized.quality_score, commentary.quality_score);
    }

    #[test]
    fn test_commentary_output_serialization() {
        let output = CommentaryOutput {
            commentary: sample_commentary(),
            generation_time: 2.5,
            tokens_used: 100,
            truncated: false,
            suggested_improvements: vec![
                "Add more details".to_string(),
                "Improve flow".to_string(),
            ],
        };
        let json = serde_json::to_string(&output).unwrap();
        let deserialized: CommentaryOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.generation_time, output.generation_time);
        assert_eq!(deserialized.tokens_used, output.tokens_used);
        assert_eq!(deserialized.truncated, output.truncated);
        assert_eq!(
            deserialized.suggested_improvements,
            output.suggested_improvements
        );
    }
}