use std::collections::{BTreeMap, BTreeSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use super::model::{Annotation, Category, Dataset, DatasetInfo, Image};
use super::{AnnotationId, BBoxXYXY, CategoryId, ImageId, Pixel};
use crate::error::PanlabelError;
/// Column names of the optional header row. `parse_csv_rows` skips the first record
/// when it matches these names exactly.
const HEADER_COLUMNS: [&'static str; 6] = ["path", "x1", "y1", "x2", "y2", "class_name"];
/// One data record of a RetinaNet CSV file after field-level validation by
/// `parse_csv_rows` (a skipped header row never becomes a `RetinanetRow`).
#[derive(Debug)]
enum RetinanetRow {
    /// A record with every bbox/class field present: `path,x1,y1,x2,y2,class_name`.
    Annotation {
        /// Image reference exactly as written in the CSV (relative or absolute).
        path: String,
        /// `x1` column parsed as `f64`; the writer fills it from the box's `xmin()`.
        x1: f64,
        /// `y1` column parsed as `f64`; the writer fills it from the box's `ymin()`.
        y1: f64,
        /// `x2` column parsed as `f64`; the writer fills it from the box's `xmax()`.
        x2: f64,
        /// `y2` column parsed as `f64`; the writer fills it from the box's `ymax()`.
        y2: f64,
        /// Category name; distinct names become categories in `retinanet_to_ir`.
        class_name: String,
    },
    /// A record whose five bbox/class fields are all empty (`path,,,,,`): the image
    /// is listed but has no annotations.
    Empty { path: String },
}
/// Reads a RetinaNet CSV annotation file from `path` into a [`Dataset`].
///
/// Relative image references inside the CSV are resolved against the directory that
/// contains `path` (or `.` when `path` has no parent). Each referenced image must be
/// readable, because its dimensions are taken from the file itself.
///
/// # Errors
///
/// Returns [`PanlabelError::Io`] if the CSV file cannot be opened. Parsing and image
/// errors from the row parser and the IR conversion are propagated unchanged.
pub fn read_retinanet_csv(path: &Path) -> Result<Dataset, PanlabelError> {
    let file = File::open(path).map_err(PanlabelError::Io)?;
    let rows = parse_csv_rows(BufReader::new(file), path)?;
    let base_dir = match path.parent() {
        Some(parent) => parent,
        None => Path::new("."),
    };
    retinanet_to_ir(rows, base_dir, path)
}
/// Writes `dataset` to `path` as RetinaNet CSV, creating or truncating the file.
///
/// The whole CSV is rendered in memory first (see `to_retinanet_csv_string`), so a
/// dataset that fails to serialize never touches the file system.
///
/// # Errors
///
/// Serialization errors are propagated unchanged. File creation, write and flush
/// failures are returned as [`PanlabelError::Io`].
pub fn write_retinanet_csv(path: &Path, dataset: &Dataset) -> Result<(), PanlabelError> {
    let contents = to_retinanet_csv_string(dataset)?;
    let mut out = BufWriter::new(File::create(path).map_err(PanlabelError::Io)?);
    // Flush explicitly: `BufWriter`'s drop would swallow a flush error.
    out.write_all(contents.as_bytes())
        .and_then(|()| out.flush())
        .map_err(PanlabelError::Io)
}
/// Parses RetinaNet CSV text held in memory into a [`Dataset`].
///
/// Relative image references are resolved against `base_dir`. Because there is no
/// real CSV file, errors are labelled with the synthetic path `base_dir/<string>`.
///
/// # Errors
///
/// Same as `read_retinanet_csv`, minus the file-open error.
pub fn from_retinanet_csv_str_with_base_dir(
    csv: &str,
    base_dir: &Path,
) -> Result<Dataset, PanlabelError> {
    let pseudo_source = base_dir.join("<string>");
    parse_csv_rows(csv.as_bytes(), &pseudo_source)
        .and_then(|rows| retinanet_to_ir(rows, base_dir, &pseudo_source))
}
/// Fuzzing entry point: runs the raw `csv` reader over `bytes` and discards every
/// record, returning the first CSV-level error encountered.
///
/// The reader is built with headers disabled and otherwise default settings. None of
/// the RetinaNet-specific validation in `parse_csv_rows` is applied here.
#[cfg(feature = "fuzzing")]
pub fn parse_retinanet_csv_slice(bytes: &[u8]) -> Result<(), csv::Error> {
    csv::ReaderBuilder::new()
        .has_headers(false)
        .from_reader(bytes)
        .records()
        .try_for_each(|record| record.map(|_| ()))
}
pub fn to_retinanet_csv_string(dataset: &Dataset) -> Result<String, PanlabelError> {
let dummy_path = Path::new("<string>");
let category_lookup: BTreeMap<CategoryId, &Category> =
dataset.categories.iter().map(|cat| (cat.id, cat)).collect();
let mut anns_by_image: BTreeMap<ImageId, Vec<&Annotation>> = BTreeMap::new();
for ann in &dataset.annotations {
anns_by_image.entry(ann.image_id).or_default().push(ann);
}
let mut sorted_images: Vec<&Image> = dataset.images.iter().collect();
sorted_images.sort_by(|a, b| a.file_name.cmp(&b.file_name));
let mut csv_writer = csv::WriterBuilder::new()
.has_headers(false)
.from_writer(Vec::new());
for img in sorted_images {
match anns_by_image.get(&img.id) {
Some(anns) if !anns.is_empty() => {
let mut sorted_anns: Vec<&Annotation> = anns.clone();
sorted_anns.sort_by_key(|a| a.id);
for ann in sorted_anns {
let category = category_lookup.get(&ann.category_id).ok_or_else(|| {
PanlabelError::RetinanetCsvInvalid {
path: dummy_path.to_path_buf(),
message: format!(
"Annotation {} references non-existent category {}",
ann.id.as_u64(),
ann.category_id.as_u64()
),
}
})?;
csv_writer
.write_record([
&img.file_name,
&ann.bbox.xmin().to_string(),
&ann.bbox.ymin().to_string(),
&ann.bbox.xmax().to_string(),
&ann.bbox.ymax().to_string(),
&category.name,
])
.map_err(|source| PanlabelError::RetinanetCsvWrite {
path: dummy_path.to_path_buf(),
source,
})?;
}
}
_ => {
csv_writer
.write_record([&img.file_name, "", "", "", "", ""])
.map_err(|source| PanlabelError::RetinanetCsvWrite {
path: dummy_path.to_path_buf(),
source,
})?;
}
}
}
let bytes = csv_writer
.into_inner()
.map_err(|e| PanlabelError::Io(e.into_error()))?;
String::from_utf8(bytes).map_err(|e| PanlabelError::RetinanetCsvInvalid {
path: dummy_path.to_path_buf(),
message: format!("Invalid UTF-8 in output: {e}"),
})
}
fn parse_csv_rows<R: std::io::Read>(
reader: R,
source_path: &Path,
) -> Result<Vec<RetinanetRow>, PanlabelError> {
let mut csv_reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_reader(reader);
let mut rows = Vec::new();
let mut is_first = true;
let mut row_num: usize = 0;
for result in csv_reader.records() {
row_num += 1;
let record = result.map_err(|source| PanlabelError::RetinanetCsvParse {
path: source_path.to_path_buf(),
source,
})?;
if record.len() != 6 {
return Err(PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: expected 6 columns, got {} in row: {:?}",
row_num,
record.len(),
record.iter().collect::<Vec<_>>()
),
});
}
let col0 = record.get(0).unwrap_or("");
let col1 = record.get(1).unwrap_or("");
let col2 = record.get(2).unwrap_or("");
let col3 = record.get(3).unwrap_or("");
let col4 = record.get(4).unwrap_or("");
let col5 = record.get(5).unwrap_or("");
if is_first {
is_first = false;
if col0 == HEADER_COLUMNS[0]
&& col1 == HEADER_COLUMNS[1]
&& col2 == HEADER_COLUMNS[2]
&& col3 == HEADER_COLUMNS[3]
&& col4 == HEADER_COLUMNS[4]
&& col5 == HEADER_COLUMNS[5]
{
continue;
}
}
if col0.is_empty() {
return Err(PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!("row {}: empty path field", row_num),
});
}
let bbox_fields = [col1, col2, col3, col4, col5];
let all_empty = bbox_fields.iter().all(|f| f.is_empty());
let all_present = bbox_fields.iter().all(|f| !f.is_empty());
if all_empty {
rows.push(RetinanetRow::Empty {
path: col0.to_string(),
});
} else if all_present {
let x1: f64 = col1
.parse()
.map_err(|_| PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: invalid x1 value '{}' for image '{}'",
row_num, col1, col0
),
})?;
let y1: f64 = col2
.parse()
.map_err(|_| PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: invalid y1 value '{}' for image '{}'",
row_num, col2, col0
),
})?;
let x2: f64 = col3
.parse()
.map_err(|_| PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: invalid x2 value '{}' for image '{}'",
row_num, col3, col0
),
})?;
let y2: f64 = col4
.parse()
.map_err(|_| PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: invalid y2 value '{}' for image '{}'",
row_num, col4, col0
),
})?;
rows.push(RetinanetRow::Annotation {
path: col0.to_string(),
x1,
y1,
x2,
y2,
class_name: col5.to_string(),
});
} else {
return Err(PanlabelError::RetinanetCsvInvalid {
path: source_path.to_path_buf(),
message: format!(
"row {}: partial annotation row for image '{}': some bbox/class fields are empty while others are present",
row_num, col0
),
});
}
}
Ok(rows)
}
fn retinanet_to_ir(
rows: Vec<RetinanetRow>,
base_dir: &Path,
source_path: &Path,
) -> Result<Dataset, PanlabelError> {
let mut image_paths: Vec<String> = Vec::new();
let mut seen_paths = std::collections::BTreeSet::new();
for row in &rows {
let p = match row {
RetinanetRow::Annotation { path, .. } => path,
RetinanetRow::Empty { path } => path,
};
if seen_paths.insert(p.clone()) {
image_paths.push(p.clone());
}
}
image_paths.sort();
let mut dim_cache: BTreeMap<String, (u32, u32)> = BTreeMap::new();
for img_path in &image_paths {
let dims = resolve_image_dimensions(base_dir, img_path, source_path)?;
dim_cache.insert(img_path.clone(), dims);
}
let image_map: BTreeMap<String, ImageId> = image_paths
.iter()
.enumerate()
.map(|(i, p)| (p.clone(), ImageId::new((i + 1) as u64)))
.collect();
let images: Vec<Image> = image_paths
.iter()
.map(|p| {
let id = image_map[p];
let (width, height) = dim_cache[p];
Image::new(id, p.clone(), width, height)
})
.collect();
let category_names: BTreeSet<String> = rows
.iter()
.filter_map(|row| match row {
RetinanetRow::Annotation { class_name, .. } => Some(class_name.clone()),
RetinanetRow::Empty { .. } => None,
})
.collect();
let category_map: BTreeMap<String, CategoryId> = category_names
.iter()
.enumerate()
.map(|(i, name)| (name.clone(), CategoryId::new((i + 1) as u64)))
.collect();
let categories: Vec<Category> = category_names
.iter()
.map(|name| {
let id = category_map[name];
Category::new(id, name.clone())
})
.collect();
let mut annotations = Vec::new();
let mut ann_id_counter: u64 = 1;
for row in rows {
if let RetinanetRow::Annotation {
path,
x1,
y1,
x2,
y2,
class_name,
} = row
{
let image_id = image_map[&path];
let category_id = category_map[&class_name];
let bbox = BBoxXYXY::<Pixel>::from_xyxy(x1, y1, x2, y2);
annotations.push(Annotation::new(
AnnotationId::new(ann_id_counter),
image_id,
category_id,
bbox,
));
ann_id_counter += 1;
}
}
Ok(Dataset {
info: DatasetInfo::default(),
licenses: vec![],
images,
categories,
annotations,
})
}
fn resolve_image_dimensions(
base_dir: &Path,
image_ref: &str,
source_path: &Path,
) -> Result<(u32, u32), PanlabelError> {
let image_path = if Path::new(image_ref).is_absolute() {
PathBuf::from(image_ref)
} else {
base_dir.join(image_ref)
};
let size = imagesize::size(&image_path).map_err(|source| {
if !image_path.exists() {
return PanlabelError::RetinanetImageNotFound {
path: source_path.to_path_buf(),
image_ref: image_ref.to_string(),
};
}
PanlabelError::RetinanetImageDimensionRead {
path: image_path.clone(),
source,
}
})?;
Ok((size.width as u32, size.height as u32))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_csv_rows` over in-memory CSV text with a fixed fake source path.
    fn parse(csv: &str) -> Result<Vec<RetinanetRow>, PanlabelError> {
        parse_csv_rows(std::io::Cursor::new(csv.as_bytes()), Path::new("test.csv"))
    }

    /// Parses `csv`, requires it to fail, and returns the error's display text.
    fn parse_error(csv: &str) -> String {
        match parse(csv) {
            Ok(rows) => panic!("expected a parse error, got {} row(s)", rows.len()),
            Err(err) => err.to_string(),
        }
    }

    /// Serializes `dataset` and returns the output split into owned lines.
    fn write_lines(dataset: &Dataset) -> Vec<String> {
        to_retinanet_csv_string(dataset)
            .expect("serialize failed")
            .lines()
            .map(str::to_owned)
            .collect()
    }

    #[test]
    fn test_write_annotated_images() {
        let dataset = Dataset {
            images: vec![
                Image::new(1u64, "img_a.jpg", 640, 480),
                Image::new(2u64, "img_b.jpg", 800, 600),
            ],
            categories: vec![Category::new(1u64, "cat"), Category::new(2u64, "dog")],
            annotations: vec![
                Annotation::new(
                    1u64,
                    1u64,
                    1u64,
                    BBoxXYXY::<Pixel>::from_xyxy(10.0, 20.0, 100.0, 200.0),
                ),
                Annotation::new(
                    2u64,
                    1u64,
                    2u64,
                    BBoxXYXY::<Pixel>::from_xyxy(50.0, 60.0, 150.0, 250.0),
                ),
                Annotation::new(
                    3u64,
                    2u64,
                    1u64,
                    BBoxXYXY::<Pixel>::from_xyxy(5.0, 10.0, 55.0, 110.0),
                ),
            ],
            ..Default::default()
        };
        let lines = write_lines(&dataset);
        assert_eq!(lines.len(), 3);
        for (line, prefix) in lines.iter().zip(["img_a.jpg,", "img_a.jpg,", "img_b.jpg,"]) {
            assert!(line.starts_with(prefix), "line {line:?} should start with {prefix:?}");
        }
        assert!(lines[0].contains(",cat"));
        assert!(lines[1].contains(",dog"));
    }

    #[test]
    fn test_write_unannotated_image() {
        let dataset = Dataset {
            images: vec![Image::new(1u64, "empty.jpg", 100, 100)],
            categories: vec![],
            annotations: vec![],
            ..Default::default()
        };
        let lines = write_lines(&dataset);
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0], "empty.jpg,,,,,");
    }

    #[test]
    fn test_write_mixed_annotated_and_unannotated() {
        let dataset = Dataset {
            images: vec![
                Image::new(1u64, "annotated.jpg", 640, 480),
                Image::new(2u64, "empty.jpg", 100, 100),
            ],
            categories: vec![Category::new(1u64, "person")],
            annotations: vec![Annotation::new(
                1u64,
                1u64,
                1u64,
                BBoxXYXY::<Pixel>::from_xyxy(10.0, 20.0, 100.0, 200.0),
            )],
            ..Default::default()
        };
        let lines = write_lines(&dataset);
        assert_eq!(lines.len(), 2);
        assert!(lines[0].starts_with("annotated.jpg,10"));
        assert_eq!(lines[1], "empty.jpg,,,,,");
    }

    #[test]
    fn test_write_deterministic_order() {
        // Images and annotations are given out of order on purpose.
        let dataset = Dataset {
            images: vec![
                Image::new(2u64, "z.jpg", 100, 100),
                Image::new(1u64, "a.jpg", 100, 100),
            ],
            categories: vec![Category::new(1u64, "obj")],
            annotations: vec![
                Annotation::new(
                    3u64,
                    2u64,
                    1u64,
                    BBoxXYXY::<Pixel>::from_xyxy(0.0, 0.0, 10.0, 10.0),
                ),
                Annotation::new(
                    1u64,
                    1u64,
                    1u64,
                    BBoxXYXY::<Pixel>::from_xyxy(0.0, 0.0, 10.0, 10.0),
                ),
                Annotation::new(
                    2u64,
                    1u64,
                    1u64,
                    BBoxXYXY::<Pixel>::from_xyxy(5.0, 5.0, 15.0, 15.0),
                ),
            ],
            ..Default::default()
        };
        let lines = write_lines(&dataset);
        assert!(lines[0].starts_with("a.jpg,"));
        assert!(lines[1].starts_with("a.jpg,"));
        assert!(lines[2].starts_with("z.jpg,"));
    }

    #[test]
    fn test_write_missing_category_error() {
        let dataset = Dataset {
            images: vec![Image::new(1u64, "test.jpg", 100, 100)],
            categories: vec![],
            annotations: vec![Annotation::new(
                1u64,
                1u64,
                1u64,
                BBoxXYXY::<Pixel>::from_xyxy(0.0, 0.0, 10.0, 10.0),
            )],
            ..Default::default()
        };
        assert!(to_retinanet_csv_string(&dataset).is_err());
    }

    #[test]
    fn test_parse_annotation_row() {
        let rows = parse("img.jpg,10,20,100,200,person\n").expect("parse failed");
        assert_eq!(rows.len(), 1);
        if let RetinanetRow::Annotation {
            path,
            x1,
            y1,
            x2,
            y2,
            class_name,
        } = &rows[0]
        {
            assert_eq!(path, "img.jpg");
            assert_eq!((*x1, *y1, *x2, *y2), (10.0, 20.0, 100.0, 200.0));
            assert_eq!(class_name, "person");
        } else {
            panic!("expected annotation row");
        }
    }

    #[test]
    fn test_parse_empty_row() {
        let rows = parse("img.jpg,,,,,\n").expect("parse failed");
        assert_eq!(rows.len(), 1);
        if let RetinanetRow::Empty { path } = &rows[0] {
            assert_eq!(path, "img.jpg");
        } else {
            panic!("expected empty row");
        }
    }

    #[test]
    fn test_parse_header_skipped() {
        let rows = parse("path,x1,y1,x2,y2,class_name\nimg.jpg,10,20,100,200,cat\n")
            .expect("parse failed");
        assert_eq!(rows.len(), 1);
        if let RetinanetRow::Annotation { path, .. } = &rows[0] {
            assert_eq!(path, "img.jpg");
        } else {
            panic!("expected annotation row");
        }
    }

    #[test]
    fn test_parse_no_header() {
        let rows = parse("img.jpg,10,20,100,200,cat\n").expect("parse failed");
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_parse_partial_row_rejected() {
        assert!(parse_error("img.jpg,10,,100,200,cat\n").contains("partial annotation row"));
    }

    #[test]
    fn test_parse_empty_path_rejected() {
        assert!(parse_error(",10,20,100,200,cat\n").contains("empty path field"));
    }

    #[test]
    fn test_parse_invalid_coordinate_rejected() {
        assert!(parse_error("img.jpg,abc,20,100,200,cat\n").contains("invalid x1 value"));
    }

    #[test]
    fn test_parse_multiple_images() {
        let csv = "a.jpg,10,20,100,200,cat\na.jpg,50,60,150,250,dog\nb.jpg,,,,,\n";
        let rows = parse(csv).expect("parse failed");
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_parse_float_coordinates() {
        let rows = parse("img.jpg,10.5,20.3,100.7,200.9,person\n").expect("parse failed");
        match &rows[0] {
            RetinanetRow::Annotation { x1, y1, x2, y2, .. } => {
                let actual = [*x1, *y1, *x2, *y2];
                let expected = [10.5, 20.3, 100.7, 200.9];
                for (got, want) in actual.iter().zip(expected.iter()) {
                    assert!((got - want).abs() < 1e-9, "{got} != {want}");
                }
            }
            RetinanetRow::Empty { .. } => panic!("expected annotation row"),
        }
    }
}