use crate::config::OutputConfig;
use crate::{DataGeneratorError, Result};
use srdf::srdf_graph::SRDFGraph;
use srdf::{BuildRDF, NeighsRDF, RDFFormat};
use std::fs::File;
use std::path::PathBuf;
/// Writes a generated RDF graph to disk according to an [`OutputConfig`].
///
/// Depending on the configuration, the graph is written either to a single
/// file or split across several part-files; optional statistics and a
/// manifest may be written next to the output.
pub struct OutputWriter {
/// Output settings; holds a clone of the configuration passed to `new`.
config: OutputConfig,
}
impl OutputWriter {
/// Creates a writer that owns its own copy of `config`.
///
/// The caller's configuration is cloned, so later changes to the original
/// do not affect this writer. The visible body always returns `Ok`.
pub fn new(config: &OutputConfig) -> Result<Self> {
    let config = config.clone();
    Ok(Self { config })
}
/// Writes `graph` without any generation-timing information.
///
/// Convenience wrapper around `write_graph_with_timing` that passes `None`
/// for the timing.
///
/// # Errors
/// Propagates any error returned by `write_graph_with_timing`.
pub async fn write_graph(&self, graph: &SRDFGraph) -> Result<()> {
    let no_timing: Option<std::time::Duration> = None;
    self.write_graph_with_timing(graph, no_timing).await
}
/// Writes `graph`, choosing the strategy from `config.parallel_writing`.
///
/// `generation_time`, when present, is forwarded to the selected strategy,
/// which passes it on to the statistics output.
///
/// # Errors
/// Propagates any error returned by the selected write strategy.
pub async fn write_graph_with_timing(
    &self,
    graph: &SRDFGraph,
    generation_time: Option<std::time::Duration>,
) -> Result<()> {
    // Single-file output is the fallback when parallel writing is disabled.
    if !self.config.parallel_writing {
        return self.write_graph_sequential(graph, generation_time).await;
    }
    self.write_graph_parallel(graph, generation_time).await
}
/// Serializes the whole graph into the single file at `config.path`.
///
/// Steps, in order:
/// 1. create the parent directory of the output path if it has one,
/// 2. serialize the graph in the configured RDF format,
/// 3. write statistics when `config.write_stats` is set,
/// 4. run the compression hook when `config.compress` is set.
///
/// # Errors
/// Returns an error if the directory or file cannot be created, if
/// serialization fails (`DataGeneratorError::OutputWriting`), if flushing
/// the buffered output fails, or if the statistics/compression steps fail.
async fn write_graph_sequential(
    &self,
    graph: &SRDFGraph,
    generation_time: Option<std::time::Duration>,
) -> Result<()> {
    use std::io::Write as _;

    let format = self.get_rdf_format();
    if let Some(parent) = self.config.path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    {
        // Buffer the output: the serializer may issue many small writes and an
        // unbuffered `File` would turn each one into a system call.
        let mut writer = std::io::BufWriter::new(File::create(&self.config.path)?);
        graph.serialize(&format, &mut writer).map_err(|e| {
            DataGeneratorError::OutputWriting(format!("Failed to serialize graph: {e}"))
        })?;
        // Flush explicitly: dropping a `BufWriter` discards flush errors.
        writer.flush()?;
    }
    tracing::info!("Graph written to: {}", self.config.path.display());
    if self.config.write_stats {
        self.write_statistics(graph, generation_time).await?;
    }
    if self.config.compress {
        self.compress_output().await?;
    }
    Ok(())
}
/// Splits the graph's triples into chunks and writes each chunk to its own
/// part-file in a separate spawned task.
///
/// The number of part-files is derived from `config.get_optimal_file_count`;
/// the last chunk may be smaller than the others, and fewer files than the
/// requested count may be produced. After all chunks are written, statistics
/// (if enabled), the compression hook (if enabled) and a manifest listing the
/// part-files are produced.
///
/// If the graph yields no triples, a warning is logged and the function
/// returns early without writing part-files, statistics or a manifest.
///
/// # Errors
/// Returns an error if the output directory cannot be created, the triples
/// cannot be collected, any write task fails or panics, or a follow-up step
/// (statistics, compression, manifest) fails.
async fn write_graph_parallel(
    &self,
    graph: &SRDFGraph,
    generation_time: Option<std::time::Duration>,
) -> Result<()> {
    let start_time = std::time::Instant::now();
    if let Some(parent) = self.config.path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let all_triples = graph
        .triples()
        .map_err(|e| {
            DataGeneratorError::OutputWriting(format!("Failed to collect triples: {e}"))
        })?
        .collect::<Vec<_>>();
    let total_triples = all_triples.len();
    // Clamp to at least one file: `div_ceil(0)` would panic.
    let optimal_file_count = self.config.get_optimal_file_count(total_triples).max(1);
    let chunk_size = total_triples.div_ceil(optimal_file_count);
    if chunk_size == 0 {
        tracing::warn!("No triples to write");
        return Ok(());
    }
    let triple_chunks: Vec<_> = all_triples.chunks(chunk_size).collect();
    // Rounding the chunk size up can yield fewer chunks than requested, so
    // remember how many part-files are really produced for the manifest.
    let actual_file_count = triple_chunks.len();
    tracing::info!(
        "Writing {} triples in {} parallel files ({} triples per file)",
        total_triples,
        actual_file_count,
        chunk_size
    );
    let file_tasks: Vec<_> = triple_chunks
        .into_iter()
        .enumerate()
        .map(|(index, chunk)| {
            let format = self.get_rdf_format();
            let output_path = self.get_parallel_file_path(index);
            // Spawned tasks must own their data, so each chunk is copied.
            let chunk_triples = chunk.to_vec();
            tokio::spawn(async move {
                Self::write_triple_chunk(chunk_triples, format, output_path).await
            })
        })
        .collect();
    // The outer error is a task join failure; the inner results carry the
    // per-chunk write outcome and are checked below.
    let write_results = futures::future::try_join_all(file_tasks)
        .await
        .map_err(|e| {
            DataGeneratorError::OutputWriting(format!("Parallel write task failed: {e}"))
        })?;
    for result in write_results {
        result?;
    }
    let write_time = start_time.elapsed();
    tracing::info!("Parallel writing completed in {:?}", write_time);
    if self.config.write_stats {
        self.write_statistics(graph, generation_time).await?;
    }
    if self.config.compress {
        self.compress_parallel_output().await?;
    }
    // Use the real chunk count so the manifest header is accurate and stale
    // part-files from earlier runs with more parts are not listed.
    self.create_parallel_manifest(actual_file_count).await?;
    Ok(())
}
/// Writes one chunk of triples to `output_path` in the given RDF format.
///
/// The triples are first added to a fresh `SRDFGraph`, which is then
/// serialized to the file. This is an associated function taking only owned
/// arguments, which lets the caller move it into a spawned task.
///
/// # Errors
/// Returns `DataGeneratorError::OutputWriting` if a triple cannot be added
/// or the chunk cannot be serialized, and an I/O-derived error if the file
/// cannot be created or flushed.
async fn write_triple_chunk(
    triples: Vec<oxrdf::Triple>,
    format: RDFFormat,
    output_path: PathBuf,
) -> Result<()> {
    use std::io::Write as _;

    let mut chunk_graph = SRDFGraph::default();
    for triple in triples {
        chunk_graph
            .add_triple(triple.subject, triple.predicate, triple.object)
            .map_err(|e| {
                DataGeneratorError::OutputWriting(format!("Failed to add triple to chunk: {e}"))
            })?;
    }
    // Buffer the output: the serializer may issue many small writes and an
    // unbuffered `File` would turn each one into a system call.
    let mut writer = std::io::BufWriter::new(File::create(&output_path)?);
    chunk_graph.serialize(&format, &mut writer).map_err(|e| {
        DataGeneratorError::OutputWriting(format!("Failed to serialize chunk: {e}"))
    })?;
    // Flush explicitly: dropping a `BufWriter` discards flush errors.
    writer.flush()?;
    tracing::debug!("Chunk written to: {}", output_path.display());
    Ok(())
}
/// Builds the path of the part-file for the zero-based chunk `index`.
///
/// The result has the form `<parent>/<stem>_part_<NNN>.<ext>`, where `NNN`
/// is `index + 1` zero-padded to three digits. Missing or non-UTF-8
/// components of the configured path fall back to `"output"` for the stem,
/// `"ttl"` for the extension and `"."` for the parent directory.
fn get_parallel_file_path(&self, index: usize) -> PathBuf {
    let base = &self.config.path;
    let stem = base
        .file_stem()
        .and_then(|name| name.to_str())
        .unwrap_or("output");
    let extension = base
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("ttl");
    let directory = match base.parent() {
        Some(dir) => dir,
        None => std::path::Path::new("."),
    };
    // Part numbers shown to users are one-based.
    let part_number = index + 1;
    directory.join(format!("{}_part_{:03}.{}", stem, part_number, extension))
}
/// Writes a plain-text manifest listing the part-files of a parallel run.
///
/// The manifest is stored at the configured output path with its extension
/// replaced by `manifest.txt`. It starts with three `#`-prefixed header
/// lines (title, UTC timestamp, file count) followed by one line per
/// part-file path for indices `0..actual_file_count`; paths that do not
/// exist on disk are skipped.
///
/// # Errors
/// Returns `DataGeneratorError::OutputWriting` if the manifest cannot be
/// written.
async fn create_parallel_manifest(&self, actual_file_count: usize) -> Result<()> {
    let manifest_path = self.config.path.with_extension("manifest.txt");
    // Header lines first; the part-file listing is appended afterwards.
    let mut sections = vec![
        String::from("# Data Generator Parallel Output Manifest\n"),
        format!(
            "# Generated on: {}\n",
            chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
        ),
        format!("# Total parallel files: {actual_file_count}\n\n"),
    ];
    sections.extend(
        (0..actual_file_count)
            .map(|part_index| self.get_parallel_file_path(part_index))
            .filter(|part_path| part_path.exists())
            .map(|part_path| format!("{}\n", part_path.display())),
    );
    let manifest_content = sections.concat();
    let written = tokio::fs::write(manifest_path, manifest_content).await;
    written.map_err(|e| {
        DataGeneratorError::OutputWriting(format!("Failed to write manifest: {e}"))
    })?;
    Ok(())
}
/// Compression hook for parallel output.
///
/// Currently a stub: it only logs a warning and returns `Ok(())`; no file
/// is compressed.
async fn compress_parallel_output(&self) -> Result<()> {
tracing::warn!("Parallel output compression not yet implemented");
Ok(())
}
async fn write_statistics(
&self,
graph: &SRDFGraph,
generation_time: Option<std::time::Duration>,
) -> Result<()> {
let stats_path = self.config.path.with_extension("stats.json");
let mut stats = GenerationStatistics::from_graph(graph);
if let Some(duration) = generation_time {
stats = stats.with_timing(duration);
}
let stats_json = serde_json::to_string_pretty(&stats)?;
std::fs::write(stats_path, stats_json)?;
Ok(())
}
/// Compression hook for single-file output.
///
/// Currently a stub: it only logs a warning and returns `Ok(())`; no file
/// is compressed.
async fn compress_output(&self) -> Result<()> {
tracing::warn!("Output compression not yet implemented");
Ok(())
}
fn get_rdf_format(&self) -> RDFFormat {
match self.config.format {
crate::config::OutputFormat::Turtle => RDFFormat::Turtle,
crate::config::OutputFormat::NTriples => RDFFormat::NTriples,
}
}
}
/// Summary statistics about a generated graph, serializable with `serde`.
#[derive(Debug, serde::Serialize)]
pub struct GenerationStatistics {
/// Number of triples, as reported by the graph's `len()`.
pub total_triples: usize,
/// Number of distinct subjects (compared by their string rendering).
pub total_subjects: usize,
/// Number of distinct predicates (compared by their string rendering).
pub total_predicates: usize,
/// Number of distinct objects (compared by their string rendering).
pub total_objects: usize,
/// Generation time formatted as `"<milliseconds>ms"`; `None` until `with_timing` is called.
pub generation_time: Option<String>,
/// Count of `rdf:type` triples per object, keyed by the object's string
/// rendering with surrounding `<`/`>` removed.
pub shape_counts: std::collections::HashMap<String, usize>,
}
impl GenerationStatistics {
pub fn from_graph(graph: &SRDFGraph) -> Self {
use std::collections::HashSet;
let total_triples = graph.len();
let mut subjects = HashSet::new();
let mut predicates = HashSet::new();
let mut objects = HashSet::new();
let mut shape_counts = std::collections::HashMap::new();
if let Ok(triples) = graph.triples() {
for triple in triples {
subjects.insert(triple.subject.to_string());
let pred_str = triple.predicate.to_string();
predicates.insert(pred_str.clone());
objects.insert(triple.object.to_string());
if pred_str == "<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>" {
let shape_type = triple.object.to_string();
let shape_type = shape_type.trim_start_matches('<').trim_end_matches('>');
*shape_counts.entry(shape_type.to_string()).or_insert(0) += 1;
}
}
}
Self {
total_triples,
total_subjects: subjects.len(),
total_predicates: predicates.len(),
total_objects: objects.len(),
generation_time: None,
shape_counts,
}
}
pub fn with_timing(mut self, duration: std::time::Duration) -> Self {
self.generation_time = Some(format!("{}ms", duration.as_millis()));
self
}
}