use crate::graph::graph::Graph;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::hash::Hash;
use std::io::{self, BufRead, BufReader, Write};
/// Persistence of a graph as a pair of CSV files: one describing its nodes and one
/// describing its edges.
///
/// `W`, `N` and `E` are the graph's edge-weight, node-data and edge-data types; each method
/// states the formatting/parsing capabilities it needs from them.
pub trait CsvIO<W, N, E> {
    /// Writes the graph's nodes to `nodes_file` and its edges to `edges_file`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or writing either file.
    fn save_to_csv(&self, nodes_file: &str, edges_file: &str) -> io::Result<()>
    where
        W: Copy + PartialEq + std::fmt::Display,
        N: Clone + Eq + Hash + std::fmt::Debug + std::fmt::Display,
        E: Clone + std::fmt::Debug + std::fmt::Display;

    /// Builds a new graph from a node file and an edge file (such as those written by
    /// [`CsvIO::save_to_csv`]); `directed` is forwarded to the graph being constructed.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if either file cannot be read, or an error describing the first
    /// record whose contents cannot be parsed or applied to the graph.
    fn load_from_csv(nodes_file: &str, edges_file: &str, directed: bool) -> io::Result<Self>
    where
        Self: Sized,
        W: Copy + PartialEq + Default + std::str::FromStr,
        N: Clone + Eq + Hash + std::fmt::Debug + std::str::FromStr,
        E: Clone + std::fmt::Debug + Default + std::str::FromStr,
        <W as std::str::FromStr>::Err: std::fmt::Debug,
        <N as std::str::FromStr>::Err: std::fmt::Debug,
        <E as std::str::FromStr>::Err: std::fmt::Debug;
}
/// Renders `value` as a single CSV cell.
///
/// Values containing a comma, a double quote, or a line-break character are wrapped in
/// double quotes, with each embedded quote doubled (`"` becomes `""`). Any other value is
/// returned unchanged.
fn escape_csv_field(value: &str) -> String {
    let needs_quoting = value
        .chars()
        .any(|ch| matches!(ch, ',' | '"' | '\n' | '\r'));
    if !needs_quoting {
        return value.to_string();
    }

    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if ch == '"' {
            // A literal quote inside a quoted cell is written twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Splits one CSV record into its cells.
///
/// A `"` outside quotes opens a quoted section, and a `"` inside one closes it, except that
/// `""` inside quotes yields a literal `"`. Commas separate cells only outside quotes. The
/// result always has at least one element: an empty record yields `[""]`. An unterminated
/// quote simply runs to the end of the input.
fn split_csv_line(line: &str) -> Vec<String> {
    let mut cells: Vec<String> = Vec::new();
    let mut cell = String::new();
    let mut quoted = false;
    let mut rest = line.chars().peekable();

    while let Some(ch) = rest.next() {
        match (quoted, ch) {
            (true, '"') => {
                // `""` is an escaped quote; a lone `"` ends the quoted section.
                if rest.next_if_eq(&'"').is_some() {
                    cell.push('"');
                } else {
                    quoted = false;
                }
            }
            (false, '"') => quoted = true,
            (false, ',') => cells.push(std::mem::take(&mut cell)),
            (_, other) => cell.push(other),
        }
    }

    cells.push(cell);
    cells
}
impl<W, N, E> CsvIO<W, N, E> for Graph<W, N, E>
where
W: Copy + PartialEq,
N: Clone + Eq + Hash + std::fmt::Debug,
E: Clone + std::fmt::Debug + Default,
{
fn save_to_csv(&self, nodes_file: &str, edges_file: &str) -> io::Result<()>
where
W: std::fmt::Display,
N: std::fmt::Display,
E: std::fmt::Display,
{
let mut nodes_writer = File::create(nodes_file)?;
let mut node_attrs: Vec<String> = self
.nodes
.iter()
.flat_map(|(_, node)| node.attributes.keys())
.collect::<HashSet<_>>()
.into_iter()
.cloned()
.collect();
node_attrs.sort();
let mut header = vec!["node_id".to_string(), "data".to_string()];
header.extend(node_attrs.iter().map(|k| escape_csv_field(k)));
writeln!(nodes_writer, "{}", header.join(","))?;
for (id, node) in self.nodes.iter() {
let mut row = vec![id.to_string(), escape_csv_field(&node.data.to_string())];
row.extend(node_attrs.iter().map(|key| {
node.attributes
.get(key)
.map_or(String::new(), |v| escape_csv_field(&v.to_string()))
}));
writeln!(nodes_writer, "{}", row.join(","))?;
}
let mut edges_writer = File::create(edges_file)?;
let mut edge_attrs: Vec<String> = self
.edges
.iter()
.flat_map(|(_, edge)| edge.attributes.keys())
.collect::<HashSet<_>>()
.into_iter()
.cloned()
.collect();
edge_attrs.sort();
let mut header = vec![
"from".to_string(),
"to".to_string(),
"weight".to_string(),
"data".to_string(),
];
header.extend(edge_attrs.iter().map(|k| escape_csv_field(k)));
writeln!(edges_writer, "{}", header.join(","))?;
for (_, edge) in self.edges.iter() {
let mut row = vec![
edge.from.to_string(),
edge.to.to_string(),
edge.weight.to_string(),
escape_csv_field(&edge.data.to_string()),
];
row.extend(edge_attrs.iter().map(|key| {
edge.attributes
.get(key)
.map_or(String::new(), |v| escape_csv_field(&v.to_string()))
}));
writeln!(edges_writer, "{}", row.join(","))?;
}
Ok(())
}
fn load_from_csv(nodes_file: &str, edges_file: &str, directed: bool) -> io::Result<Self>
where
W: Default + std::str::FromStr,
N: std::str::FromStr,
E: std::str::FromStr,
<W as std::str::FromStr>::Err: std::fmt::Debug,
<N as std::str::FromStr>::Err: std::fmt::Debug,
<E as std::str::FromStr>::Err: std::fmt::Debug,
{
let mut graph = Graph::new(directed);
let nodes_reader = BufReader::new(File::open(nodes_file)?);
let mut lines = nodes_reader.lines();
let header = lines
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Empty nodes file"))??;
let attr_keys: Vec<String> = split_csv_line(&header).into_iter().skip(2).collect();
let mut id_map: HashMap<usize, usize> = HashMap::new();
for line in lines {
let line = line?;
if line.is_empty() {
continue;
}
let parts = split_csv_line(&line);
let original_id: usize = parts
.first()
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "Node line is missing an ID")
})?
.parse()
.map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Node ID parse error: {:?}", e),
)
})?;
let data = parts
.get(1)
.map(String::as_str)
.unwrap_or("")
.parse()
.map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Node data parse error: {:?}", e),
)
})?;
let node = graph.add_node(data);
id_map.insert(original_id, node);
for (key, value) in attr_keys.iter().zip(parts.iter().skip(2)) {
if !value.is_empty() {
graph
.set_node_attribute(node, key.clone(), value.clone())
.map_err(|e| {
io::Error::other(format!("Failed to set node attribute: {:?}", e))
})?;
}
}
}
let edges_reader = BufReader::new(File::open(edges_file)?);
let mut lines = edges_reader.lines();
let header = lines
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Empty edges file"))??;
let attr_keys: Vec<String> = split_csv_line(&header).into_iter().skip(4).collect();
for line in lines {
let line = line?;
if line.is_empty() {
continue;
}
let parts = split_csv_line(&line);
if parts.len() < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Edge line has too few fields: {}", line),
));
}
let from_raw: usize = parts[0].parse().map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Edge 'from' parse error: {:?}", e),
)
})?;
let to_raw: usize = parts[1].parse().map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Edge 'to' parse error: {:?}", e),
)
})?;
let weight = parts[2].parse().map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Weight parse error: {:?}", e),
)
})?;
let data = parts[3].parse().map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Edge data parse error: {:?}", e),
)
})?;
let from = id_map.get(&from_raw).copied().unwrap_or(from_raw);
let to = id_map.get(&to_raw).copied().unwrap_or(to_raw);
graph.add_edge(from, to, weight, data).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Add edge error: {:?}", e),
)
})?;
for (key, value) in attr_keys.iter().zip(parts.iter().skip(4)) {
if !value.is_empty() {
graph
.set_edge_attribute(from, to, key.clone(), value.clone())
.map_err(|e| {
io::Error::other(format!("Failed to set edge attribute: {:?}", e))
})?;
}
}
}
Ok(graph)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A pair of node/edge CSV paths in the system temp directory.
    ///
    /// The files are removed when the guard is dropped, including when an assertion panics,
    /// so test runs leave no CSV files in the crate's working directory. The process ID is
    /// part of each name so concurrent test runs do not clobber each other's files.
    struct TempCsvFiles {
        nodes: String,
        edges: String,
    }

    impl TempCsvFiles {
        fn new(stem: &str) -> Self {
            let path_for = |kind: &str| {
                std::env::temp_dir()
                    .join(format!("{}_{}_{}.csv", stem, std::process::id(), kind))
                    .to_str()
                    .expect("temp dir path should be valid UTF-8")
                    .to_string()
            };
            TempCsvFiles {
                nodes: path_for("nodes"),
                edges: path_for("edges"),
            }
        }
    }

    impl Drop for TempCsvFiles {
        fn drop(&mut self) {
            // Best effort: a file may not exist if the test failed before saving.
            let _ = std::fs::remove_file(&self.nodes);
            let _ = std::fs::remove_file(&self.edges);
        }
    }

    /// Returns the ID the loaded graph assigned to the node whose data is `data`.
    ///
    /// Loading may renumber nodes, so tests look nodes up by content rather than reusing
    /// IDs from the graph that was saved.
    fn id_of(graph: &Graph<u32, String, String>, data: &str) -> usize {
        graph
            .all_nodes()
            .find(|(_, d)| d.as_str() == data)
            .map(|(id, _)| id)
            .expect("node should survive the round trip")
    }

    #[test]
    fn test_csv_io() {
        let files = TempCsvFiles::new("test_io");
        let mut graph = Graph::<u32, String, String>::new(false);
        let n1 = graph.add_node("A".to_string());
        let n2 = graph.add_node("B".to_string());
        graph.add_edge(n1, n2, 1, "edge".to_string()).unwrap();
        graph
            .set_node_attribute(n1, "color".to_string(), "red".to_string())
            .unwrap();
        graph
            .set_edge_attribute(n1, n2, "type".to_string(), "road".to_string())
            .unwrap();
        graph.save_to_csv(&files.nodes, &files.edges).unwrap();
        let loaded_graph =
            Graph::<u32, String, String>::load_from_csv(&files.nodes, &files.edges, false)
                .unwrap();
        assert_eq!(graph.nodes.len(), loaded_graph.nodes.len());
        assert_eq!(graph.edges.len(), loaded_graph.edges.len());
        let a = id_of(&loaded_graph, "A");
        let b = id_of(&loaded_graph, "B");
        assert_eq!(
            loaded_graph.get_node_attribute(a, "color"),
            Some(&"red".to_string())
        );
        assert_eq!(
            loaded_graph.get_edge_attribute(a, b, "type"),
            Some(&"road".to_string())
        );
        let edges = loaded_graph.get_all_edges();
        assert_eq!(edges[0].3, "edge".to_string());
    }

    #[test]
    fn test_csv_io_quoted_values() {
        let files = TempCsvFiles::new("test_quote");
        let mut graph = Graph::<u32, String, String>::new(true);
        let n1 = graph.add_node("Paris, France".to_string());
        let n2 = graph.add_node("Berlin".to_string());
        graph
            .add_edge(n1, n2, 5, "rail, high-speed".to_string())
            .unwrap();
        graph.save_to_csv(&files.nodes, &files.edges).unwrap();
        let loaded =
            Graph::<u32, String, String>::load_from_csv(&files.nodes, &files.edges, true)
                .unwrap();
        let paris = id_of(&loaded, "Paris, France");
        assert_eq!(
            loaded.get_node_attribute(paris, "missing"),
            None,
            "no spurious attributes should appear"
        );
        let nodes: Vec<_> = loaded.all_nodes().map(|(_, d)| d.clone()).collect();
        assert!(nodes.contains(&"Paris, France".to_string()));
        assert_eq!(loaded.get_all_edges()[0].3, "rail, high-speed".to_string());
    }

    #[test]
    fn test_csv_io_remaps_sparse_ids() {
        let files = TempCsvFiles::new("test_sparse");
        let mut graph = Graph::<u32, String, String>::new(true);
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        let c = graph.add_node("C".to_string());
        graph.add_edge(a, c, 7, "ac".to_string()).unwrap();
        graph.remove_node(b).unwrap();
        graph.save_to_csv(&files.nodes, &files.edges).unwrap();
        let loaded =
            Graph::<u32, String, String>::load_from_csv(&files.nodes, &files.edges, true)
                .unwrap();
        assert_eq!(loaded.nodes.len(), 2);
        assert_eq!(loaded.edges.len(), 1);
        let (from, to, weight, data) = loaded.get_all_edges()[0].clone();
        assert_eq!(weight, 7);
        assert_eq!(data, "ac".to_string());
        assert_eq!(
            loaded.all_nodes().find(|(id, _)| *id == from).unwrap().1,
            "A"
        );
        assert_eq!(loaded.all_nodes().find(|(id, _)| *id == to).unwrap().1, "C");
    }
}