use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
use thiserror::Error;
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Errors produced by the fine-tuning data-collection and job-management APIs.
#[derive(Debug, Error)]
pub enum FinetuningError {
    /// Filesystem failure (storage-dir creation, dataset export, ...).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// JSON (de)serialization failure via `serde_json`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
    /// Lookup of a dataset id that is not registered with the manager.
    #[error("Dataset not found: {0}")]
    DatasetNotFound(String),
    /// Too few examples pass the quality threshold to export or train.
    #[error("Insufficient training data: need {needed}, have {available}")]
    InsufficientData { needed: usize, available: usize },
    /// Lookup of a job id that is not registered with the manager.
    #[error("Training job not found: {0}")]
    JobNotFound(String),
    /// A user-supplied configuration value is invalid.
    #[error("Invalid configuration: {0}")]
    ConfigurationError(String),
    /// Training data failed a quality check.
    #[error("Data quality error: {0}")]
    DataQualityError(String),
}
/// Convenience alias for results returned throughout this module.
pub type FinetuningResult<T> = Result<T, FinetuningError>;
/// Wire formats supported when exporting a dataset as JSONL.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExportFormat {
    /// OpenAI chat fine-tuning format: `{"messages": [{role, content}, ...]}`.
    OpenAiChatJsonl,
    /// Simple `{"prompt", "response"}` records.
    HuggingFaceInstruct,
    /// Alpaca-style `{"instruction", "input", "output"}` records.
    AlpacaInstruct,
    /// Minimal `{"input", "output"}` pairs.
    ConversationPairs,
    /// The full serialized `TrainingExample`, all fields included.
    FullConversationJson,
}
/// One human/assistant exchange collected as fine-tuning training data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Unique identifier (UUID v4 string).
    pub id: String,
    /// Optional system prompt, emitted as the first chat turn on OpenAI export.
    pub system_prompt: Option<String>,
    /// The user's message.
    pub human_message: String,
    /// The assistant's reply.
    pub assistant_response: String,
    /// Optional domain label used for dataset statistics.
    pub domain: Option<String>,
    /// Heuristic quality in [0.0, 1.0]; 0.5 is the unscored default.
    pub quality_score: f32,
    /// Whether a human has reviewed this example.
    pub human_reviewed: bool,
    /// Optional user rating; scoring treats 3 as neutral (presumably 1-5).
    pub user_rating: Option<u8>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Chat session this example was captured from, if known.
    pub source_session_id: Option<String>,
    /// When the example was collected (UTC).
    pub collected_at: DateTime<Utc>,
    /// Arbitrary key/value metadata.
    pub metadata: HashMap<String, String>,
}
impl TrainingExample {
    /// Builds an example from one human/assistant exchange using neutral
    /// defaults: a fresh UUID, the current UTC timestamp, an unscored quality
    /// of 0.5, and no review, rating, tags, or metadata.
    pub fn new(human_message: String, assistant_response: String) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            collected_at: Utc::now(),
            human_message,
            assistant_response,
            system_prompt: None,
            domain: None,
            quality_score: 0.5,
            human_reviewed: false,
            user_rating: None,
            tags: Vec::new(),
            source_session_id: None,
            metadata: HashMap::new(),
        }
    }

    /// True when this example's quality score is at least `threshold`.
    pub fn meets_quality_threshold(&self, threshold: f32) -> bool {
        threshold <= self.quality_score
    }

    /// Renders the example as an OpenAI chat-completions training record
    /// (`{"messages": [...]}`), with the optional system turn first.
    pub fn to_openai_format(&self) -> serde_json::Value {
        let system_turn = self
            .system_prompt
            .iter()
            .map(|prompt| serde_json::json!({ "role": "system", "content": prompt }));
        let exchange = [
            serde_json::json!({ "role": "user", "content": self.human_message }),
            serde_json::json!({ "role": "assistant", "content": self.assistant_response }),
        ];
        let messages: Vec<serde_json::Value> = system_turn.chain(exchange).collect();
        serde_json::json!({ "messages": messages })
    }

    /// Renders the example in Alpaca instruction format; the human message is
    /// the instruction and the `input` field is always empty.
    pub fn to_alpaca_format(&self) -> serde_json::Value {
        serde_json::json!({
            "instruction": self.human_message,
            "input": "",
            "output": self.assistant_response
        })
    }

    /// Recomputes `quality_score` from simple heuristics — response byte
    /// length, user rating (centered on 3), and human review — then clamps
    /// the result to [0.0, 1.0].
    pub fn compute_quality_score(&mut self) {
        let answer_bytes = self.assistant_response.len();
        let mut score = 0.5f32;
        // Longer answers earn up to +0.2 (the two bonuses stack).
        if answer_bytes > 100 {
            score += 0.1;
        }
        if answer_bytes > 500 {
            score += 0.1;
        }
        // A rating shifts the score by 0.1 per star away from the neutral 3.
        if let Some(stars) = self.user_rating {
            score += (f32::from(stars) - 3.0) * 0.1;
        }
        if self.human_reviewed {
            score += 0.2;
        }
        // Penalize very short exchanges on either side.
        if answer_bytes < 20 {
            score -= 0.3;
        }
        if self.human_message.len() < 10 {
            score -= 0.1;
        }
        self.quality_score = score.clamp(0.0, 1.0);
    }
}
/// A named, versioned collection of training examples with cached statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDataset {
    /// Unique dataset identifier.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Optional domain label for the dataset as a whole.
    pub domain: Option<String>,
    /// The collected examples, in insertion order.
    pub examples: Vec<TrainingExample>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp (UTC), refreshed on `add_example`.
    pub updated_at: DateTime<Utc>,
    /// Dataset version string; starts at "1.0.0".
    pub version: String,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Aggregate statistics; kept in sync by `recompute_stats`.
    pub stats: DatasetStats,
}
/// Aggregate statistics over a dataset's examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    /// Total number of examples.
    pub total_examples: usize,
    /// Examples with quality_score >= 0.7.
    pub high_quality_examples: usize,
    /// Examples flagged as human-reviewed.
    pub human_reviewed_examples: usize,
    /// Mean quality score across all examples.
    pub average_quality_score: f32,
    /// Mean human-message length in bytes.
    pub average_human_len: f32,
    /// Mean assistant-response length in bytes.
    pub average_assistant_len: f32,
    /// Number of distinct `domain` values present.
    pub unique_domains: usize,
}
impl TrainingDataset {
pub fn new(id: String, name: String) -> Self {
let now = Utc::now();
Self {
id,
name,
description: String::new(),
domain: None,
examples: Vec::new(),
created_at: now,
updated_at: now,
version: "1.0.0".to_string(),
tags: Vec::new(),
stats: DatasetStats {
total_examples: 0,
high_quality_examples: 0,
human_reviewed_examples: 0,
average_quality_score: 0.0,
average_human_len: 0.0,
average_assistant_len: 0.0,
unique_domains: 0,
},
}
}
pub fn add_example(&mut self, mut example: TrainingExample) {
if (example.quality_score - 0.5).abs() < f32::EPSILON {
example.compute_quality_score();
}
self.examples.push(example);
self.updated_at = Utc::now();
self.recompute_stats();
}
pub fn examples_above_quality(&self, threshold: f32) -> Vec<&TrainingExample> {
self.examples
.iter()
.filter(|e| e.quality_score >= threshold)
.collect()
}
pub fn deduplicate(&mut self) -> usize {
let before = self.examples.len();
let mut seen_messages = std::collections::HashSet::new();
self.examples.retain(|e| {
let key = e.human_message.trim().to_lowercase();
seen_messages.insert(key)
});
let removed = before - self.examples.len();
if removed > 0 {
info!("Deduplicated {} examples from dataset {}", removed, self.id);
self.recompute_stats();
}
removed
}
pub fn export_jsonl(
&self,
format: &ExportFormat,
quality_threshold: f32,
) -> FinetuningResult<String> {
let examples = self.examples_above_quality(quality_threshold);
if examples.is_empty() {
return Err(FinetuningError::InsufficientData {
needed: 1,
available: 0,
});
}
let mut lines = Vec::with_capacity(examples.len());
for example in &examples {
let json = match format {
ExportFormat::OpenAiChatJsonl => example.to_openai_format(),
ExportFormat::AlpacaInstruct => example.to_alpaca_format(),
ExportFormat::HuggingFaceInstruct => serde_json::json!({
"prompt": example.human_message,
"response": example.assistant_response,
}),
ExportFormat::ConversationPairs => serde_json::json!({
"input": example.human_message,
"output": example.assistant_response,
}),
ExportFormat::FullConversationJson => serde_json::to_value(example)?,
};
lines.push(serde_json::to_string(&json)?);
}
Ok(lines.join("\n"))
}
pub fn recompute_stats(&mut self) {
let total = self.examples.len();
if total == 0 {
self.stats = DatasetStats {
total_examples: 0,
high_quality_examples: 0,
human_reviewed_examples: 0,
average_quality_score: 0.0,
average_human_len: 0.0,
average_assistant_len: 0.0,
unique_domains: 0,
};
return;
}
let high_quality = self
.examples
.iter()
.filter(|e| e.quality_score >= 0.7)
.count();
let reviewed = self.examples.iter().filter(|e| e.human_reviewed).count();
let total_quality: f32 = self.examples.iter().map(|e| e.quality_score).sum();
let total_human_len: usize = self.examples.iter().map(|e| e.human_message.len()).sum();
let total_assistant_len: usize = self
.examples
.iter()
.map(|e| e.assistant_response.len())
.sum();
let unique_domains: std::collections::HashSet<&str> = self
.examples
.iter()
.filter_map(|e| e.domain.as_deref())
.collect();
self.stats = DatasetStats {
total_examples: total,
high_quality_examples: high_quality,
human_reviewed_examples: reviewed,
average_quality_score: total_quality / total as f32,
average_human_len: total_human_len as f32 / total as f32,
average_assistant_len: total_assistant_len as f32 / total as f32,
unique_domains: unique_domains.len(),
};
}
}
/// Lifecycle states of a fine-tuning job.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum JobStatus {
    /// Created but not yet started.
    Pending,
    /// Preparing/exporting the training data.
    PreparingData,
    /// Actively training.
    Training,
    /// Finished successfully; `result_model` is set.
    Completed,
    /// Finished with an error; `error_message` is set.
    Failed,
    /// Cancelled before completion.
    Cancelled,
}
/// Configuration for a single fine-tuning job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinetuningJobConfig {
    /// Base model identifier to fine-tune (e.g. "gpt-3.5-turbo").
    pub base_model: String,
    /// Suffix appended to the resulting model name.
    pub suffix: String,
    /// Number of training epochs.
    pub epochs: u32,
    /// Training batch size.
    pub batch_size: u32,
    /// Multiplier applied to the provider's base learning rate.
    pub learning_rate_multiplier: f32,
    /// Minimum quality score an example needs to be included.
    pub quality_threshold: f32,
    /// JSONL format used when exporting the training data.
    pub export_format: ExportFormat,
    /// Optional cap on the number of examples used.
    pub max_examples: Option<usize>,
    /// Whether to hold out a validation split.
    pub run_validation: bool,
    /// Fraction of examples reserved for validation (e.g. 0.1).
    pub validation_split: f32,
}
impl Default for FinetuningJobConfig {
    /// Sensible defaults: gpt-3.5-turbo base, 3 epochs, batch size 4,
    /// quality threshold 0.6, OpenAI chat format, 10% validation split.
    fn default() -> Self {
        Self {
            base_model: "gpt-3.5-turbo".to_string(),
            suffix: "oxirs-chat".to_string(),
            epochs: 3,
            batch_size: 4,
            learning_rate_multiplier: 1.0,
            quality_threshold: 0.6,
            export_format: ExportFormat::OpenAiChatJsonl,
            max_examples: None,
            run_validation: true,
            validation_split: 0.1,
        }
    }
}
/// State of one fine-tuning job, from creation through completion or failure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinetuningJob {
    /// Unique job identifier (UUID v4 string).
    pub id: String,
    /// The dataset this job trains on.
    pub dataset_id: String,
    /// Training configuration snapshot.
    pub config: FinetuningJobConfig,
    /// Current lifecycle state.
    pub status: JobStatus,
    /// When the job record was created (UTC).
    pub created_at: DateTime<Utc>,
    /// When training started, if it has.
    pub started_at: Option<DateTime<Utc>>,
    /// When the job completed or failed, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Provider-side job id, if submitted externally.
    pub external_job_id: Option<String>,
    /// Identifier of the produced model, set on completion.
    pub result_model: Option<String>,
    /// Number of examples used for training.
    pub training_examples: usize,
    /// Number of examples used for validation.
    pub validation_examples: usize,
    /// Most recently reported training step.
    pub current_step: Option<u64>,
    /// Total expected steps, if known.
    pub total_steps: Option<u64>,
    /// Training-loss history, one entry per recorded step.
    pub training_loss: Vec<f32>,
    /// Validation-loss history (recorded only when provided).
    pub validation_loss: Vec<f32>,
    /// Failure reason, set when status is `Failed`.
    pub error_message: Option<String>,
    /// Arbitrary key/value metadata.
    pub metadata: HashMap<String, String>,
}
impl FinetuningJob {
    /// Creates a pending job for `dataset_id` with the given configuration.
    pub fn new(dataset_id: String, config: FinetuningJobConfig) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            dataset_id,
            config,
            status: JobStatus::Pending,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            external_job_id: None,
            result_model: None,
            training_examples: 0,
            validation_examples: 0,
            current_step: None,
            total_steps: None,
            training_loss: Vec::new(),
            validation_loss: Vec::new(),
            error_message: None,
            metadata: HashMap::new(),
        }
    }

    /// Marks the job as training and records the start time.
    pub fn start(&mut self) {
        self.status = JobStatus::Training;
        self.started_at = Some(Utc::now());
        info!("Fine-tuning job {} started", self.id);
    }

    /// Records one training step: updates the step counter and appends the
    /// losses (validation loss only when supplied).
    pub fn record_step(&mut self, step: u64, train_loss: f32, val_loss: Option<f32>) {
        self.current_step = Some(step);
        self.training_loss.push(train_loss);
        if let Some(vl) = val_loss {
            self.validation_loss.push(vl);
        }
        debug!(
            "Job {} step {}: train_loss={:.4}",
            self.id, step, train_loss
        );
    }

    /// Marks the job completed and stores the resulting model identifier.
    pub fn complete(&mut self, result_model: String) {
        self.status = JobStatus::Completed;
        self.completed_at = Some(Utc::now());
        info!(
            "Fine-tuning job {} completed. Model: {}",
            self.id, result_model
        );
        // Log before storing so the string can be moved instead of cloned.
        self.result_model = Some(result_model);
    }

    /// Marks the job failed and stores the error message.
    pub fn fail(&mut self, error: String) {
        self.status = JobStatus::Failed;
        self.completed_at = Some(Utc::now());
        warn!("Fine-tuning job {} failed: {}", self.id, error);
        // Log before storing so the string can be moved instead of cloned.
        self.error_message = Some(error);
    }

    /// Wall-clock duration, available once both start and completion times
    /// have been recorded.
    pub fn duration(&self) -> Option<chrono::Duration> {
        self.started_at
            .zip(self.completed_at)
            .map(|(start, end)| end - start)
    }

    /// Fractional progress in [0.0, 1.0]; falls back to coarse estimates
    /// (1.0 completed, 0.0 pending, 0.5 otherwise) when step counts are
    /// unknown.
    pub fn progress(&self) -> f32 {
        match (self.current_step, self.total_steps) {
            (Some(current), Some(total)) if total > 0 => current as f32 / total as f32,
            _ => match self.status {
                JobStatus::Completed => 1.0,
                JobStatus::Pending => 0.0,
                _ => 0.5,
            },
        }
    }

    /// True while the job is pending, preparing data, or training.
    pub fn is_active(&self) -> bool {
        matches!(
            self.status,
            JobStatus::Pending | JobStatus::PreparingData | JobStatus::Training
        )
    }
}
/// User feedback about a conversation (or one message within it).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationFeedback {
    /// Unique feedback identifier (UUID v4 string).
    pub id: String,
    /// Session the feedback refers to.
    pub session_id: String,
    /// Specific message within the session, if any.
    pub message_id: Option<String>,
    /// Star rating, clamped to 1..=5 by the constructor.
    pub rating: u8,
    /// Optional free-text comment.
    pub comment: Option<String>,
    /// User flagged the response as factually incorrect.
    pub flagged_incorrect: bool,
    /// User flagged the response as harmful.
    pub flagged_harmful: bool,
    /// Categorized issues the user selected.
    pub issue_categories: Vec<FeedbackIssue>,
    /// Submission timestamp (UTC).
    pub submitted_at: DateTime<Utc>,
    /// Submitting user, if known.
    pub user_id: Option<String>,
}
/// Categories a user can attach to negative feedback.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FeedbackIssue {
    Inaccurate,
    Unhelpful,
    TooLong,
    TooShort,
    OffTopic,
    Confusing,
    Biased,
    Outdated,
    MissingContext,
}
impl ConversationFeedback {
    /// Builds feedback for a session with a fresh UUID and the current UTC
    /// timestamp; `rating` is clamped into the 1..=5 range.
    pub fn new(session_id: String, rating: u8) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            session_id,
            message_id: None,
            rating: rating.clamp(1, 5),
            comment: None,
            flagged_incorrect: false,
            flagged_harmful: false,
            issue_categories: Vec::new(),
            submitted_at: Utc::now(),
            user_id: None,
        }
    }

    /// Maps the 1-5 star rating linearly onto a 0.0-1.0 quality score
    /// (1 star -> 0.0, 5 stars -> 1.0).
    pub fn rating_to_quality(&self) -> f32 {
        (f32::from(self.rating) - 1.0) / 4.0
    }
}
/// Configuration for the `FinetuningManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinetuningManagerConfig {
    /// Directory for persisted data; created by `FinetuningManager::new`.
    pub storage_dir: PathBuf,
    /// Minimum qualifying examples required before a job can be created.
    pub min_examples_for_export: usize,
    /// Quality threshold used by `export_dataset` when none is given.
    pub default_quality_threshold: f32,
    /// Whether `add_example` periodically deduplicates datasets.
    pub auto_deduplicate: bool,
    /// Per-dataset example cap; 0 disables the cap.
    pub max_dataset_size: usize,
}
impl Default for FinetuningManagerConfig {
    /// Defaults: `data/finetuning` storage, 10-example export minimum,
    /// 0.5 quality threshold, auto-dedup on, 100k example cap.
    fn default() -> Self {
        Self {
            storage_dir: PathBuf::from("data/finetuning"),
            min_examples_for_export: 10,
            default_quality_threshold: 0.5,
            auto_deduplicate: true,
            max_dataset_size: 100_000,
        }
    }
}
/// In-memory registry of training datasets, fine-tuning jobs, and user
/// feedback, plus JSONL export to the configured storage directory.
pub struct FinetuningManager {
    // Manager-wide settings (thresholds, caps, storage path).
    config: FinetuningManagerConfig,
    // Datasets keyed by their id.
    datasets: HashMap<String, TrainingDataset>,
    // Jobs keyed by their id.
    jobs: HashMap<String, FinetuningJob>,
    // All submitted feedback, in submission order.
    feedback: Vec<ConversationFeedback>,
}
impl FinetuningManager {
    /// Creates a manager, ensuring the configured storage directory exists.
    ///
    /// # Errors
    /// `IoError` if the storage directory cannot be created.
    pub fn new(config: FinetuningManagerConfig) -> FinetuningResult<Self> {
        fs::create_dir_all(&config.storage_dir)?;
        Ok(Self {
            config,
            datasets: HashMap::new(),
            jobs: HashMap::new(),
            feedback: Vec::new(),
        })
    }

    /// Registers a new, empty dataset and returns its generated id.
    pub fn create_dataset(
        &mut self,
        name: String,
        description: String,
    ) -> FinetuningResult<String> {
        let dataset_id = Uuid::new_v4().to_string();
        let mut dataset = TrainingDataset::new(dataset_id.clone(), name);
        dataset.description = description;
        info!("Created fine-tuning dataset: {}", dataset.id);
        self.datasets.insert(dataset_id.clone(), dataset);
        Ok(dataset_id)
    }

    /// Looks up a dataset by id.
    pub fn get_dataset(&self, id: &str) -> FinetuningResult<&TrainingDataset> {
        match self.datasets.get(id) {
            Some(dataset) => Ok(dataset),
            None => Err(FinetuningError::DatasetNotFound(id.to_string())),
        }
    }

    /// Looks up a dataset by id, mutably.
    pub fn get_dataset_mut(&mut self, id: &str) -> FinetuningResult<&mut TrainingDataset> {
        match self.datasets.get_mut(id) {
            Some(dataset) => Ok(dataset),
            None => Err(FinetuningError::DatasetNotFound(id.to_string())),
        }
    }

    /// Adds an example to a dataset, evicting the oldest entry when the
    /// configured size cap is reached and deduplicating periodically.
    pub fn add_example(
        &mut self,
        dataset_id: &str,
        example: TrainingExample,
    ) -> FinetuningResult<()> {
        let cap = self.config.max_dataset_size;
        let dedup_enabled = self.config.auto_deduplicate;
        let dataset = self.get_dataset_mut(dataset_id)?;
        // A cap of 0 means "unbounded".
        if cap > 0 && dataset.examples.len() >= cap {
            warn!(
                "Dataset {} is at maximum size ({}), dropping oldest example",
                dataset_id, cap
            );
            dataset.examples.remove(0);
        }
        dataset.add_example(example);
        // Deduplicate only every 1000th insertion to amortize the scan cost.
        if dedup_enabled && dataset.examples.len() % 1000 == 0 {
            dataset.deduplicate();
        }
        Ok(())
    }

    /// Stores user feedback and returns its id.
    pub fn submit_feedback(&mut self, feedback: ConversationFeedback) -> FinetuningResult<String> {
        debug!(
            "Received feedback for session {}: rating={}",
            feedback.session_id, feedback.rating
        );
        let feedback_id = feedback.id.clone();
        self.feedback.push(feedback);
        Ok(feedback_id)
    }

    /// All feedback recorded for `session_id`, in submission order.
    pub fn get_session_feedback(&self, session_id: &str) -> Vec<&ConversationFeedback> {
        let mut matching = Vec::new();
        for entry in &self.feedback {
            if entry.session_id == session_id {
                matching.push(entry);
            }
        }
        matching
    }

    /// Creates a pending fine-tuning job, first verifying the dataset holds
    /// enough examples above the job's quality threshold.
    ///
    /// # Errors
    /// `DatasetNotFound` for an unknown dataset; `InsufficientData` when too
    /// few examples qualify.
    pub fn create_job(
        &mut self,
        dataset_id: &str,
        config: FinetuningJobConfig,
    ) -> FinetuningResult<String> {
        let available = self
            .get_dataset(dataset_id)?
            .examples_above_quality(config.quality_threshold)
            .len();
        let needed = self.config.min_examples_for_export;
        if available < needed {
            return Err(FinetuningError::InsufficientData { needed, available });
        }
        let job = FinetuningJob::new(dataset_id.to_string(), config);
        let job_id = job.id.clone();
        info!(
            "Created fine-tuning job {} for dataset {}",
            job_id, dataset_id
        );
        self.jobs.insert(job_id.clone(), job);
        Ok(job_id)
    }

    /// Looks up a job by id.
    pub fn get_job(&self, job_id: &str) -> FinetuningResult<&FinetuningJob> {
        match self.jobs.get(job_id) {
            Some(job) => Ok(job),
            None => Err(FinetuningError::JobNotFound(job_id.to_string())),
        }
    }

    /// Looks up a job by id, mutably.
    pub fn get_job_mut(&mut self, job_id: &str) -> FinetuningResult<&mut FinetuningJob> {
        match self.jobs.get_mut(job_id) {
            Some(job) => Ok(job),
            None => Err(FinetuningError::JobNotFound(job_id.to_string())),
        }
    }

    /// Writes a dataset to `output_path` as JSONL and returns the number of
    /// lines written. Uses the configured default quality threshold when
    /// `quality_threshold` is `None`.
    pub fn export_dataset(
        &self,
        dataset_id: &str,
        output_path: &Path,
        format: &ExportFormat,
        quality_threshold: Option<f32>,
    ) -> FinetuningResult<usize> {
        let threshold = quality_threshold.unwrap_or(self.config.default_quality_threshold);
        let jsonl = self.get_dataset(dataset_id)?.export_jsonl(format, threshold)?;
        let line_count = jsonl.lines().count();
        fs::write(output_path, &jsonl)?;
        info!(
            "Exported {} examples from dataset {} to {:?}",
            line_count, dataset_id, output_path
        );
        Ok(line_count)
    }

    /// Aggregate counters across all datasets, jobs, and feedback.
    pub fn statistics(&self) -> FinetuningManagerStats {
        let total_examples: usize = self.datasets.values().map(|d| d.examples.len()).sum();
        // Count job states in a single pass over the job table.
        let mut active_jobs = 0;
        let mut completed_jobs = 0;
        let mut failed_jobs = 0;
        for job in self.jobs.values() {
            if job.is_active() {
                active_jobs += 1;
            }
            match job.status {
                JobStatus::Completed => completed_jobs += 1,
                JobStatus::Failed => failed_jobs += 1,
                _ => {}
            }
        }
        let average_rating = if self.feedback.is_empty() {
            0.0
        } else {
            let rating_sum: f32 = self.feedback.iter().map(|f| f.rating as f32).sum();
            rating_sum / self.feedback.len() as f32
        };
        FinetuningManagerStats {
            total_datasets: self.datasets.len(),
            total_examples,
            total_jobs: self.jobs.len(),
            active_jobs,
            completed_jobs,
            failed_jobs,
            total_feedback: self.feedback.len(),
            average_rating,
        }
    }

    /// Every registered dataset (arbitrary order).
    pub fn list_datasets(&self) -> Vec<&TrainingDataset> {
        self.datasets.values().collect()
    }

    /// All jobs, optionally restricted to one status (arbitrary order).
    pub fn list_jobs(&self, status_filter: Option<&JobStatus>) -> Vec<&FinetuningJob> {
        self.jobs
            .values()
            .filter(|job| status_filter.map_or(true, |wanted| &job.status == wanted))
            .collect()
    }
}
/// Snapshot of manager-wide counters returned by `FinetuningManager::statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinetuningManagerStats {
    /// Number of registered datasets.
    pub total_datasets: usize,
    /// Total examples across all datasets.
    pub total_examples: usize,
    /// Total jobs ever created.
    pub total_jobs: usize,
    /// Jobs currently pending, preparing data, or training.
    pub active_jobs: usize,
    /// Jobs that completed successfully.
    pub completed_jobs: usize,
    /// Jobs that failed.
    pub failed_jobs: usize,
    /// Total feedback entries recorded.
    pub total_feedback: usize,
    /// Mean feedback rating; 0.0 when no feedback exists.
    pub average_rating: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Convenience constructor for a bare example from `&str` inputs.
    fn make_example(human: &str, assistant: &str) -> TrainingExample {
        TrainingExample::new(human.to_string(), assistant.to_string())
    }

    /// The heuristic scorer must yield a score in (0.0, 1.0].
    #[test]
    fn test_training_example_quality_score() {
        let mut example = make_example(
            "How do I write a SPARQL query to find all classes?",
            "To find all classes in a SPARQL endpoint, you can use the following query:\n\nSELECT ?class WHERE { ?class a owl:Class }",
        );
        example.compute_quality_score();
        assert!(example.quality_score > 0.0);
        assert!(example.quality_score <= 1.0);
    }

    /// OpenAI chat export must produce a non-empty "messages" array.
    #[test]
    fn test_openai_format_export() {
        let example = make_example("What is RDF?", "RDF (Resource Description Framework) is a standard model for data interchange on the Web.");
        let json = example.to_openai_format();
        assert!(json.get("messages").is_some());
        let messages = json["messages"].as_array().expect("messages array");
        assert!(!messages.is_empty());
    }

    /// Alpaca export must carry "instruction" and "output" keys.
    #[test]
    fn test_alpaca_format_export() {
        let example = make_example("Explain SPARQL", "SPARQL is a query language for RDF data.");
        let json = example.to_alpaca_format();
        assert!(json.get("instruction").is_some());
        assert!(json.get("output").is_some());
    }

    /// A repeated human message is removed exactly once by deduplicate().
    #[test]
    fn test_dataset_add_and_deduplicate() {
        let mut dataset = TrainingDataset::new("test-ds".to_string(), "Test Dataset".to_string());
        dataset.add_example(make_example("What is RDF?", "RDF is..."));
        dataset.add_example(make_example("What is SPARQL?", "SPARQL is..."));
        dataset.add_example(make_example("What is RDF?", "RDF is..."));
        assert_eq!(dataset.examples.len(), 3);
        let removed = dataset.deduplicate();
        assert_eq!(removed, 1);
        assert_eq!(dataset.examples.len(), 2);
    }

    /// examples_above_quality must keep only examples at/above the threshold.
    #[test]
    fn test_dataset_quality_filtering() {
        let mut dataset = TrainingDataset::new("filter-ds".to_string(), "Filter Test".to_string());
        let mut low_quality = make_example("Hi", "Hi!");
        low_quality.quality_score = 0.1;
        dataset.examples.push(low_quality);
        let mut high_quality = make_example(
            "How does SPARQL federation work?",
            "SPARQL federation allows querying multiple endpoints simultaneously using the SERVICE keyword.",
        );
        high_quality.quality_score = 0.9;
        dataset.examples.push(high_quality);
        let high = dataset.examples_above_quality(0.5);
        assert_eq!(high.len(), 1);
    }

    /// Each exported JSONL line must be independently parseable JSON.
    #[test]
    fn test_dataset_export_jsonl() {
        let mut dataset = TrainingDataset::new("export-ds".to_string(), "Export Test".to_string());
        let mut example = make_example(
            "What is OWL?",
            "OWL (Web Ontology Language) is a semantic web language for defining ontologies.",
        );
        example.quality_score = 0.8;
        dataset.examples.push(example);
        let jsonl = dataset
            .export_jsonl(&ExportFormat::OpenAiChatJsonl, 0.5)
            .expect("export");
        assert!(!jsonl.is_empty());
        let line = jsonl.lines().next().expect("at least one line");
        serde_json::from_str::<serde_json::Value>(line).expect("valid JSON");
    }

    /// End-to-end: create a manager, add examples, and check the counters.
    /// Uses a unique temp dir so parallel test runs don't collide.
    #[test]
    fn test_finetuning_manager_create_and_add() {
        let dir =
            env::temp_dir().join(format!("oxirs_finetuning_test_{}", Uuid::new_v4().simple()));
        let config = FinetuningManagerConfig {
            storage_dir: dir.clone(),
            min_examples_for_export: 1,
            ..Default::default()
        };
        let mut manager = FinetuningManager::new(config).expect("create manager");
        let dataset_id = manager
            .create_dataset("Test Dataset".to_string(), "A test dataset".to_string())
            .expect("create dataset");
        for i in 0..5 {
            let mut example = make_example(
                &format!("Question {}?", i),
                &format!("Answer {} with more detail about semantic web and RDF.", i),
            );
            example.quality_score = 0.8;
            manager
                .add_example(&dataset_id, example)
                .expect("add example");
        }
        let stats = manager.statistics();
        assert_eq!(stats.total_datasets, 1);
        assert_eq!(stats.total_examples, 5);
        let _ = fs::remove_dir_all(&dir);
    }

    /// create_job succeeds when enough qualifying examples exist and the new
    /// job starts in the Pending state.
    #[test]
    fn test_create_finetuning_job() {
        let dir = env::temp_dir().join(format!("oxirs_job_test_{}", Uuid::new_v4().simple()));
        let config = FinetuningManagerConfig {
            storage_dir: dir.clone(),
            min_examples_for_export: 1,
            ..Default::default()
        };
        let mut manager = FinetuningManager::new(config).expect("create manager");
        let dataset_id = manager
            .create_dataset("Job Test".to_string(), "For job creation test".to_string())
            .expect("create dataset");
        for i in 0..3 {
            let mut example = make_example(
                &format!("SPARQL question {}?", i),
                &format!(
                    "Detailed SPARQL answer {} explaining the query pattern and expected results.",
                    i
                ),
            );
            example.quality_score = 0.8;
            manager
                .add_example(&dataset_id, example)
                .expect("add example");
        }
        let job_id = manager
            .create_job(&dataset_id, FinetuningJobConfig::default())
            .expect("create job");
        let job = manager.get_job(&job_id).expect("get job");
        assert_eq!(job.status, JobStatus::Pending);
        assert!(!job.id.is_empty());
        let _ = fs::remove_dir_all(&dir);
    }

    /// Walks a job through Pending -> Training -> Completed, checking status,
    /// loss recording, progress, and timestamps at each transition.
    #[test]
    fn test_job_lifecycle() {
        let job_config = FinetuningJobConfig::default();
        let mut job = FinetuningJob::new("dataset-1".to_string(), job_config);
        assert_eq!(job.status, JobStatus::Pending);
        assert!(job.is_active());
        assert_eq!(job.progress(), 0.0);
        job.start();
        assert_eq!(job.status, JobStatus::Training);
        job.record_step(10, 0.85, Some(0.90));
        assert_eq!(job.training_loss.len(), 1);
        job.complete("ft:gpt-3.5-turbo:oxirs".to_string());
        assert_eq!(job.status, JobStatus::Completed);
        assert!(!job.is_active());
        assert_eq!(job.progress(), 1.0);
        assert!(job.completed_at.is_some());
    }

    /// Submitted feedback is retrievable by session and reflected in stats.
    #[test]
    fn test_feedback_collection() {
        let dir = env::temp_dir().join(format!("oxirs_feedback_test_{}", Uuid::new_v4().simple()));
        let config = FinetuningManagerConfig {
            storage_dir: dir.clone(),
            ..Default::default()
        };
        let mut manager = FinetuningManager::new(config).expect("create manager");
        let feedback = ConversationFeedback::new("session-abc".to_string(), 5);
        manager.submit_feedback(feedback).expect("submit feedback");
        let session_feedback = manager.get_session_feedback("session-abc");
        assert_eq!(session_feedback.len(), 1);
        assert_eq!(session_feedback[0].rating, 5);
        let stats = manager.statistics();
        assert_eq!(stats.total_feedback, 1);
        assert_eq!(stats.average_rating, 5.0);
        let _ = fs::remove_dir_all(&dir);
    }

    /// export_dataset writes a JSONL file and reports the line count.
    #[test]
    fn test_export_to_file() {
        let dir = env::temp_dir().join(format!("oxirs_export_test_{}", Uuid::new_v4().simple()));
        let config = FinetuningManagerConfig {
            storage_dir: dir.clone(),
            min_examples_for_export: 1,
            ..Default::default()
        };
        let mut manager = FinetuningManager::new(config).expect("create manager");
        let dataset_id = manager
            .create_dataset("Export Test".to_string(), "For export test".to_string())
            .expect("create dataset");
        let mut example = make_example(
            "What is triple store?",
            "A triple store is a purpose-built database for storing and querying RDF triples.",
        );
        example.quality_score = 0.9;
        manager.add_example(&dataset_id, example).expect("add");
        let output_path = dir.join("training_data.jsonl");
        let count = manager
            .export_dataset(
                &dataset_id,
                &output_path,
                &ExportFormat::OpenAiChatJsonl,
                Some(0.5),
            )
            .expect("export");
        assert_eq!(count, 1);
        assert!(output_path.exists());
        let _ = fs::remove_dir_all(&dir);
    }

    /// recompute_stats counts high-quality (>= 0.7) and reviewed examples.
    #[test]
    fn test_dataset_stats_computation() {
        let mut dataset = TrainingDataset::new("stats-ds".to_string(), "Stats Test".to_string());
        for i in 0..10 {
            let mut example = make_example(
                &format!("Question {} about semantic web?", i),
                &format!("Detailed answer about semantic web topic {}.", i),
            );
            // First half: high quality + reviewed; second half: low quality.
            if i < 5 {
                example.quality_score = 0.8;
                example.human_reviewed = true;
            } else {
                example.quality_score = 0.4;
            }
            dataset.examples.push(example);
        }
        dataset.recompute_stats();
        assert_eq!(dataset.stats.total_examples, 10);
        assert_eq!(dataset.stats.high_quality_examples, 5);
        assert_eq!(dataset.stats.human_reviewed_examples, 5);
        assert!(dataset.stats.average_quality_score > 0.0);
    }
}