use std::path::Path;
use image::{DynamicImage, GenericImageView as _};
use ort::{GraphOptimizationLevel, Session};
use self::{bits::Bits, image_processing::ModelImage};
mod bits;
mod image_processing;
mod model;
/// Watermark encoder/decoder backed by a pair of ONNX Runtime sessions.
///
/// Built with [`Trustmark::new`], which loads the encoder and decoder model
/// files for one [`Variant`] from a models directory.
pub struct Trustmark {
/// Session loaded from the variant's encoder model file; run by `encode`.
encoder: Session,
/// Session loaded from the variant's decoder model file; run by `decode`.
decoder: Session,
/// Passed to `Bits::apply_error_correction_and_schema` when encoding a watermark.
version: Version,
/// Model variant; selects the model files, the decode input size, the
/// strength multiplier and whether boundary artifacts are removed.
variant: Variant,
}
/// Errors returned by [`Trustmark`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The watermark is corrupt or missing. The `From<bits::Error>` impl in this
/// file maps `bits::Error::CorruptWatermark` onto this variant.
#[error("watermark is corrupt or missing")]
CorruptWatermark,
/// An error from the `ort` crate (session building, inference, tensor
/// extraction); converted automatically by `?` via `#[from]`.
#[error("onnx error: {0}")]
Ort(#[from] ort::Error),
/// An error from the `image_processing` module; converted automatically by
/// `?` via `#[from]`.
#[error("image processing error: {0}")]
ImageProcessing(#[from] image_processing::Error),
/// Any `bits::Error` other than `bits::Error::CorruptWatermark`; the
/// conversion is hand-written (no `#[from]`) so that case can be split out.
#[error("bits processing error: {0}")]
Bits(bits::Error),
/// NOTE(review): never constructed in this file's visible code — presumably
/// raised by the `model` module for an unrecognised variant; confirm.
#[error("invalid model variant")]
InvalidModelVariant,
}
impl From<bits::Error> for Error {
fn from(value: bits::Error) -> Self {
match value {
bits::Error::CorruptWatermark => Error::CorruptWatermark,
err => Error::Bits(err),
}
}
}
pub use bits::Version;
pub use model::Variant;
impl Trustmark {
pub fn new<P: AsRef<Path>>(
models: P,
variant: Variant,
version: Version,
) -> Result<Self, Error> {
let encoder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(8)?
.commit_from_file(models.as_ref().join(variant.encoder_filename()))?;
let decoder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(8)?
.commit_from_file(models.as_ref().join(variant.decoder_filename()))?;
Ok(Self {
encoder,
decoder,
version,
variant,
})
}
pub fn encode(
&self,
watermark: String,
img: DynamicImage,
strength: f32,
) -> Result<DynamicImage, Error> {
let (original_width, original_height) = img.dimensions();
let aspect_ratio = original_width as f32 / original_height as f32;
let encode_size = 256;
let input_img: ort::Value<ort::TensorValueType<f32>> =
ModelImage(encode_size, self.variant, img.clone()).try_into()?;
let bits: ort::Value<ort::TensorValueType<f32>> =
Bits::apply_error_correction_and_schema(watermark, self.version)?.into();
let outputs = self.encoder.run(ort::inputs![
"onnx::Concat_0" => input_img,
"onnx::Gemm_1" => bits,
]?)?;
let output_img = outputs["image"].try_extract_tensor::<f32>()?.to_owned();
let input_img: ort::Value<ort::TensorValueType<f32>> =
ModelImage(encode_size, self.variant, img.clone()).try_into()?;
let residual = (self.variant.strength_multiplier() * strength)
* (output_img - input_img.try_extract_tensor::<f32>()?);
let mut residual = residual.clamp(-0.2, 0.2);
if (self.variant == Variant::Q && !(0.5..=2.0).contains(&aspect_ratio))
|| self.variant == Variant::P
{
residual = image_processing::remove_boundary_artifact(
residual,
(original_width as usize, original_height as usize),
self.variant,
);
}
let ModelImage(_, _, residual) = (encode_size, self.variant, residual).try_into()?;
Ok(image_processing::apply_residual(img, residual))
}
pub fn decode(&self, img: DynamicImage) -> Result<String, Error> {
let decode_size = if self.variant == Variant::P { 224 } else { 256 };
let img: ort::Value<ort::TensorValueType<f32>> =
ModelImage(decode_size, self.variant, img).try_into()?;
let outputs = self.decoder.run(ort::inputs![
"image" => img,
]?)?;
let watermark = outputs["output"].try_extract_tensor::<f32>()?.to_owned();
let watermark: Bits = watermark.try_into()?;
Ok(watermark.get_data())
}
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// The encoder and decoder models for variant Q can be loaded.
    #[test]
    fn loading_models() {
        Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
    }

    /// Encodes a fixed watermark into the image at `path`, round-trips the
    /// result through an 8-bit RGBA PNG, and checks that decoding yields the
    /// same watermark.
    ///
    /// The PNG round trip happens in memory. Writing to a shared `./test.png`
    /// made the round-trip tests race with each other, because the test
    /// harness runs them in parallel, and left a stray file behind.
    fn roundtrip(path: impl AsRef<Path>) {
        let tm = Trustmark::new("./models", Variant::Q, Version::Bch5).unwrap();
        let input = image::open(path.as_ref()).unwrap();
        let watermark = "1011011110011000111111000000011111011111011100000110110110111".to_owned();
        let encoded = tm.encode(watermark.clone(), input, 0.95).unwrap();

        let mut png = Cursor::new(Vec::new());
        DynamicImage::ImageRgba8(encoded.to_rgba8())
            .write_to(&mut png, image::ImageFormat::Png)
            .unwrap();
        let reloaded = image::load_from_memory(png.get_ref()).unwrap();

        let decoded = tm.decode(reloaded).unwrap();
        assert_eq!(watermark, decoded);
    }

    #[test]
    fn roundtrip_ghost() {
        roundtrip("../images/ghost.png");
    }

    #[test]
    fn roundtrip_ufo() {
        roundtrip("../images/ufo_240.jpg");
    }
}