use std::path::Path;
use anima_tagger_core::config::TaggerProfile;
use anima_tagger_core::hub;
use anima_tagger_core::sidecar::AutoTag;
use image::DynamicImage;
use image::imageops::FilterType;
use image::{Rgb, RgbImage};
use ndarray::Array4;
use ort::session::Session;
use ort::session::builder::GraphOptimizationLevel;
use ort::value::Tensor;
use serde::Deserialize;
use thiserror::Error;
/// Errors produced while loading the tagger's model and tag dictionary or
/// while tagging an image.
#[derive(Debug, Error)]
pub enum TaggerError {
/// Filesystem failure (converted automatically from `std::io::Error`).
#[error("io: {0}")]
Io(#[from] std::io::Error),
/// Failure reported by the `ort` crate, stored as its rendered message.
#[error("ort: {0}")]
Ort(String),
/// Failure reported by the `image` crate (converted automatically from
/// `image::ImageError`).
#[error("image: {0}")]
Image(#[from] image::ImageError),
/// The tag CSV could not be opened or one of its rows could not be parsed.
#[error("csv parse on {path}: {source}")]
Csv {
/// Path of the CSV file that was being read.
path: std::path::PathBuf,
/// Underlying error from the `csv` crate.
#[source]
source: csv::Error,
},
/// Failure reported by the project's `hub` module (converted automatically
/// from `hub::HubError`).
#[error("hub: {0}")]
Hub(#[from] hub::HubError),
/// The number of scores returned by the model differs from the number of
/// loaded tag definitions.
#[error(
"model output count ({actual}) does not match tag dictionary size ({expected}) — model and tag CSV are mismatched"
)]
OutputMismatch { expected: usize, actual: usize },
}
impl<F> From<ort::Error<F>> for TaggerError {
fn from(e: ort::Error<F>) -> Self {
TaggerError::Ort(e.to_string())
}
}
/// One entry of the tag dictionary: a tag name together with its category label.
#[derive(Debug, Clone)]
pub struct TagDef {
/// Tag name, taken verbatim from the CSV `name` column.
pub name: String,
/// Category label as text (e.g. `"general"`), derived from the CSV's
/// numeric `category` column.
pub category: String,
}
/// Raw row of the tag CSV as deserialized by serde.
///
/// Underscore-prefixed fields are parsed but not read afterwards.
#[derive(Debug, Deserialize)]
struct CsvRow {
// `tag_id` column; must parse as an integer, value is otherwise unused.
#[serde(rename = "tag_id")]
_tag_id: i64,
// `name` column: the tag's name.
name: String,
// `category` column: numeric category id.
category: i32,
// `count` column; `serde(default)` lets it be absent, value is otherwise unused.
#[serde(default, rename = "count")]
_count: Option<i64>,
}
pub fn load_tags(path: &Path) -> Result<Vec<TagDef>, TaggerError> {
let mut rdr = csv::Reader::from_path(path).map_err(|source| TaggerError::Csv {
path: path.to_path_buf(),
source,
})?;
let mut out = Vec::new();
for row in rdr.deserialize::<CsvRow>() {
let row = row.map_err(|source| TaggerError::Csv {
path: path.to_path_buf(),
source,
})?;
out.push(TagDef {
name: row.name,
category: category_name(row.category),
});
}
Ok(out)
}
/// Translates a numeric category id into its textual label.
///
/// Ids without a dedicated label are rendered as `category_<id>`.
fn category_name(id: i32) -> String {
    let label = match id {
        0 => "general",
        1 => "artist",
        3 => "copyright",
        4 => "character",
        5 => "meta",
        9 => "rating",
        // No fixed label: synthesize one from the id itself.
        _ => return format!("category_{id}"),
    };
    label.to_owned()
}
/// Converts `img` into the input tensor fed to the tagger model.
///
/// Steps, in order:
/// 1. Flatten any transparency onto a white background (fully opaque images
///    are unchanged), so see-through regions match the white padding instead
///    of exposing whatever colour is stored under alpha = 0.
/// 2. Centre the image on a white square whose side is the longer edge.
/// 3. Resize the square to `size` x `size` with Catmull-Rom filtering.
/// 4. Write the pixels into a `(1, size, size, 3)` array indexed as
///    `(batch, y, x, channel)` with channels stored in B, G, R order and
///    values left in the raw `0.0..=255.0` range.
pub fn preprocess_wd14(img: &DynamicImage, size: u32) -> Array4<f32> {
    // Blend one 8-bit channel over white using straight alpha, rounded to nearest.
    fn over_white(channel: u8, alpha: u8) -> u8 {
        let (c, a) = (u32::from(channel), u32::from(alpha));
        // Numerator is at most 255 * 255 + 127, so the quotient always fits in a u8.
        ((c * a + 255 * (255 - a) + 127) / 255) as u8
    }

    let rgba = img.to_rgba8();
    let (w, h) = rgba.dimensions();
    let rgb = RgbImage::from_fn(w, h, |x, y| {
        let [r, g, b, a] = rgba.get_pixel(x, y).0;
        Rgb([over_white(r, a), over_white(g, a), over_white(b, a)])
    });

    // Pad to a square so the later resize does not distort the aspect ratio.
    let max_side = w.max(h);
    let mut square = RgbImage::from_pixel(max_side, max_side, Rgb([255, 255, 255]));
    let dx = i64::from((max_side - w) / 2);
    let dy = i64::from((max_side - h) / 2);
    image::imageops::overlay(&mut square, &rgb, dx, dy);

    let resized = image::imageops::resize(&square, size, size, FilterType::CatmullRom);

    let s = size as usize;
    let mut arr = Array4::<f32>::zeros((1, s, s, 3));
    for (x, y, p) in resized.enumerate_pixels() {
        let [r, g, b] = p.0;
        let (row, col) = (y as usize, x as usize);
        // Channel order is deliberately reversed: blue first, red last.
        arr[(0, row, col, 0)] = f32::from(b);
        arr[(0, row, col, 1)] = f32::from(g);
        arr[(0, row, col, 2)] = f32::from(r);
    }
    arr
}
/// Image tagger: an `ort` inference session paired with the tag dictionary
/// used to label its outputs.
pub struct Tagger {
// Inference session; `from_profile` builds it from the bytes of `model.onnx`.
session: Session,
// Tag definitions loaded from `selected_tags.csv`; entry `i` labels output score `i`.
tags: Vec<TagDef>,
// Side length in pixels of the square model input, copied from the profile.
input_size: u32,
}
impl Tagger {
    /// Builds a tagger from `profile`.
    ///
    /// Fetches `model.onnx` and `selected_tags.csv` for `profile.repo` (at
    /// `profile.revision`, when set) through `hub::fetch_files`, creates an
    /// `ort` session from the model bytes with level-3 graph optimization,
    /// and loads the tag dictionary.
    ///
    /// # Errors
    ///
    /// Propagates hub, I/O, `ort` and CSV failures as [`TaggerError`].
    pub fn from_profile(profile: &TaggerProfile) -> Result<Self, TaggerError> {
        let files = hub::fetch_files(
            &profile.repo,
            profile.revision.as_deref(),
            &["model.onnx", "selected_tags.csv"],
        )?;
        // NOTE(review): assumes `fetch_files` returns one path per requested
        // name, in request order; indexing panics otherwise — confirm.
        let model_path = &files[0];
        let tags_path = &files[1];

        let bytes = std::fs::read(model_path)?;
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .commit_from_memory(&bytes)?;
        let tags = load_tags(tags_path)?;

        Ok(Self {
            session,
            tags,
            input_size: profile.input_size,
        })
    }

    /// Number of tag definitions loaded from the tag dictionary.
    pub fn num_tags(&self) -> usize {
        self.tags.len()
    }

    /// Runs the model on the image at `image_path` and returns every tag
    /// whose score is at least `threshold`, highest score first.
    ///
    /// Scores that are NaN are discarded: they cannot be meaningfully
    /// compared with `threshold`, and keeping them would make the sort
    /// comparator inconsistent.
    ///
    /// # Errors
    ///
    /// Returns [`TaggerError::Image`] if the image cannot be opened or
    /// decoded, [`TaggerError::Ort`] if building the input tensor, running
    /// the session or extracting the output fails, and
    /// [`TaggerError::OutputMismatch`] if the model emits a different number
    /// of scores than there are loaded tags.
    pub fn tag_image(
        &mut self,
        image_path: &Path,
        threshold: f32,
    ) -> Result<Vec<AutoTag>, TaggerError> {
        let img = image::open(image_path)?;
        let arr = preprocess_wd14(&img, self.input_size);

        // Tensor shape mirrors the array built by `preprocess_wd14`.
        let side = i64::from(self.input_size);
        let shape = [1_i64, side, side, 3];
        let data: Vec<f32> = arr.iter().copied().collect();
        let tensor = Tensor::from_array((shape, data))?;

        let outputs = self.session.run(ort::inputs![tensor])?;
        let (_out_shape, out_data) = outputs[0].try_extract_tensor::<f32>()?;

        let total = out_data.len();
        if total != self.tags.len() {
            return Err(TaggerError::OutputMismatch {
                expected: self.tags.len(),
                actual: total,
            });
        }

        // Lengths were just checked to be equal, so `zip` pairs every score
        // with its tag definition.
        let mut results: Vec<AutoTag> = out_data
            .iter()
            .zip(&self.tags)
            .filter(|&(&score, _)| !(score.is_nan() || score < threshold))
            .map(|(&score, tag)| AutoTag {
                tag: tag.name.clone(),
                score,
                category: tag.category.clone(),
            })
            .collect();

        // `total_cmp` is a total order, which `sort_by` requires; the sort is
        // stable, so equal scores keep their dictionary order.
        results.sort_by(|a, b| b.score.total_cmp(&a.score));
        Ok(results)
    }
}