use image::{DynamicImage, GenericImageView};
use koharu_core::download;
use ort::{inputs, session::Session, value::TensorRef};
use std::cmp::{max, min};
/// Model wrapper that fills masked image regions with the output of an ONNX
/// model executed through an `ort` session.
#[derive(Debug)]
pub struct Lama {
/// Inference session built from the `lama-manga.onnx` file fetched in `Lama::new`.
model: Session,
}
impl Lama {
    /// Downloads `lama-manga.onnx` from the `mayocream/lama-manga-onnx` hub
    /// repository and builds an inference session from it.
    ///
    /// # Errors
    ///
    /// Returns an error if the download fails or the session cannot be built
    /// from the downloaded file.
    pub async fn new() -> anyhow::Result<Self> {
        let model_path = download::hf_hub("mayocream/lama-manga-onnx", "lama-manga.onnx").await?;
        let model = Session::builder()?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(model_path)?;
        Ok(Self { model })
    }

    /// Fills the masked pixels of `image` using 512-pixel tiles that overlap by
    /// 128 pixels. See [`Lama::inference_tiled`] for details.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Lama::inference_tiled`].
    pub fn inference(
        &mut self,
        image: &DynamicImage,
        mask: &DynamicImage,
    ) -> anyhow::Result<DynamicImage> {
        self.inference_tiled(image, mask, 512, 128)
    }

    /// Fills the masked pixels of `image` tile by tile and blends overlapping
    /// tiles with a per-pixel weighted average.
    ///
    /// * `image` – source picture; it is converted to 8-bit RGB.
    /// * `mask` – converted to 8-bit luma; every pixel with a value above zero
    ///   is replaced, every other pixel is copied from `image` unchanged. It
    ///   must be at least as large as `image`; only the part overlapping the
    ///   image is read.
    /// * `tile_size` – edge length of the square tiles, raised to at least 32.
    /// * `overlap` – number of pixels shared by neighbouring tiles, limited to
    ///   `tile_size - 1` so the stride is never zero.
    ///
    /// Tiles that contain no masked pixel are skipped without running the model.
    ///
    /// # Errors
    ///
    /// Returns an error if the mask is smaller than the image or if running
    /// the model on a tile fails.
    pub fn inference_tiled(
        &mut self,
        image: &DynamicImage,
        mask: &DynamicImage,
        tile_size: u32,
        overlap: u32,
    ) -> anyhow::Result<DynamicImage> {
        let (w, h) = image.dimensions();
        let (mask_w, mask_h) = mask.dimensions();
        // A smaller mask would otherwise panic inside `get_pixel`.
        anyhow::ensure!(
            mask_w >= w && mask_h >= h,
            "mask ({}x{}) is smaller than image ({}x{})",
            mask_w,
            mask_h,
            w,
            h
        );

        let tile = max(32, tile_size);
        let ovl = min(overlap, tile - 1);
        let stride = tile - ovl;

        // Index arithmetic is done in `usize` so large images cannot overflow `u32`.
        let width = w as usize;
        let pixel_count = width * h as usize;
        // Per pixel: weighted sums of R, G, B followed by the sum of weights.
        let mut acc = vec![[0f32; 4]; pixel_count];

        let img_rgb = image.to_rgb8();
        let mask_luma = mask.to_luma8();

        let mut y0 = 0u32;
        while y0 < h {
            let y1 = min(y0.saturating_add(tile), h);
            let eff_h = y1 - y0;

            let mut x0 = 0u32;
            while x0 < w {
                let x1 = min(x0.saturating_add(tile), w);
                let eff_w = x1 - x0;

                let any_masked = (y0..y1)
                    .any(|y| (x0..x1).any(|x| mask_luma.get_pixel(x, y)[0] > 0));

                if any_masked {
                    let (tile_img, tile_mask) = extract_reflect_padded_tile(
                        &img_rgb, &mask_luma, x0, y0, eff_w, eff_h, tile,
                    );
                    let tile_out = self.infer_tile_512(&tile_img, &tile_mask)?;
                    let weights = make_tile_weights(eff_w, eff_h, ovl);

                    for yy in 0..eff_h {
                        for xx in 0..eff_w {
                            let (gx, gy) = (x0 + xx, y0 + yy);
                            // Only masked pixels are ever replaced.
                            if mask_luma.get_pixel(gx, gy)[0] == 0 {
                                continue;
                            }
                            let wgt = weights[yy as usize * eff_w as usize + xx as usize];
                            if wgt <= 0.0 {
                                continue;
                            }
                            let p = tile_out.get_pixel(xx, yy);
                            let cell = &mut acc[gy as usize * width + gx as usize];
                            cell[0] += f32::from(p[0]) * wgt;
                            cell[1] += f32::from(p[1]) * wgt;
                            cell[2] += f32::from(p[2]) * wgt;
                            cell[3] += wgt;
                        }
                    }
                }

                // Once a tile reaches the right edge, any further stride would
                // only yield a narrower tile fully contained in this one.
                if x1 == w {
                    break;
                }
                // Cannot overflow: `x0 + tile < w` here and `stride <= tile`.
                x0 += stride;
            }

            // Same reasoning for the bottom edge.
            if y1 == h {
                break;
            }
            y0 += stride;
        }

        // Start from the original pixels and overwrite those that received weight.
        let mut out = img_rgb;
        for y in 0..h {
            for x in 0..w {
                let [r, g, b, wsum] = acc[y as usize * width + x as usize];
                if wsum > 0.0 {
                    // Round instead of truncating: `p * w / w` can land just
                    // below `p` in f32 and would otherwise lose one level.
                    let blend = |sum: f32| (sum / wsum).round().clamp(0.0, 255.0) as u8;
                    out.put_pixel(x, y, image::Rgb([blend(r), blend(g), blend(b)]));
                }
            }
        }
        Ok(DynamicImage::ImageRgb8(out))
    }
}
/// Builds a `tile`×`tile` image/mask pair from the `eff_w`×`eff_h` region of
/// `img` and `mask` whose top-left corner is (`x0`, `y0`).
///
/// The region is copied into the top-left corner of the outputs. Columns to
/// its right are filled by mirroring the copied columns (the column next to
/// the region repeats its last column, and the pattern restarts every `eff_w`
/// columns). Rows below are then filled the same way from the already
/// padded rows above.
fn extract_reflect_padded_tile(
    img: &image::RgbImage,
    mask: &image::GrayImage,
    x0: u32,
    y0: u32,
    eff_w: u32,
    eff_h: u32,
    tile: u32,
) -> (image::RgbImage, image::GrayImage) {
    let mut padded_img = image::RgbImage::new(tile, tile);
    let mut padded_mask = image::GrayImage::new(tile, tile);

    // Step 1: copy the real pixels.
    for row in 0..eff_h {
        for col in 0..eff_w {
            let (sx, sy) = (x0 + col, y0 + row);
            padded_img.put_pixel(col, row, *img.get_pixel(sx, sy));
            padded_mask.put_pixel(col, row, *mask.get_pixel(sx, sy));
        }
    }

    // Maps a padding coordinate (>= extent) back to a coordinate inside the
    // copied region; an empty region falls back to coordinate 0.
    let mirror = |pos: u32, extent: u32| -> u32 {
        if extent == 0 {
            0
        } else {
            extent - 1 - ((pos - extent) % extent)
        }
    };

    // Step 2: pad to the right of the copied rows.
    for row in 0..eff_h {
        for col in eff_w..tile {
            let src_col = mirror(col, eff_w);
            let pixel = *padded_img.get_pixel(src_col, row);
            let mask_px = *padded_mask.get_pixel(src_col, row);
            padded_img.put_pixel(col, row, pixel);
            padded_mask.put_pixel(col, row, mask_px);
        }
    }

    // Step 3: pad downwards across the full width, reading rows that already
    // include the right-hand padding.
    for row in eff_h..tile {
        let src_row = mirror(row, eff_h);
        for col in 0..tile {
            let pixel = *padded_img.get_pixel(col, src_row);
            let mask_px = *padded_mask.get_pixel(col, src_row);
            padded_img.put_pixel(col, row, pixel);
            padded_mask.put_pixel(col, row, mask_px);
        }
    }

    (padded_img, padded_mask)
}
/// Returns row-major blending weights for a `w`×`h` tile.
///
/// With `overlap == 0` every weight is `1.0`. Otherwise the weight depends on
/// `d`, the distance (in pixels) to the nearest tile border:
///
/// * `d >= overlap / 2` → `1.0`
/// * `d <  overlap / 2` → raised-cosine ramp `0.5 * (1 - cos(PI * d / half))`,
///   which rises monotonically towards `1.0` and meets it continuously at
///   `d == half`.
///
/// The ramp is floored at a small positive value so that a pixel covered by a
/// single tile only (for example on the image border) still receives a
/// non-zero weight and is not dropped by callers that skip zero weights.
fn make_tile_weights(w: u32, h: u32, overlap: u32) -> Vec<f32> {
    use std::f32::consts::PI;

    /// Lower bound for ramp weights; keeps every pixel of a tile usable.
    const MIN_WEIGHT: f32 = 1e-3;

    // Sizes and indices are computed in `usize` to avoid `u32` overflow.
    let cols = w as usize;
    let mut weights = vec![1.0f32; cols * h as usize];
    if overlap == 0 {
        return weights;
    }

    let half = overlap as f32 / 2.0;
    for y in 0..h {
        let dy = min(y, h - 1 - y);
        for x in 0..w {
            let d = min(dy, min(x, w - 1 - x)) as f32;
            // Interior pixels keep the initial weight of 1.0.
            if d < half {
                let t = d / half;
                let ramp = 0.5 * (1.0 - (PI * t).cos());
                weights[y as usize * cols + x as usize] = ramp.max(MIN_WEIGHT);
            }
        }
    }
    weights
}
impl Lama {
    /// Runs the model on a single tile and returns the result as an RGB image
    /// with the same dimensions as `tile_img`.
    ///
    /// The tile is sent as an `"image"` tensor of shape `(1, 3, h, w)` with
    /// channel values scaled to `0.0..=1.0`, and the mask as a `"mask"` tensor
    /// of shape `(1, 1, h, w)` holding `1.0` where the mask pixel is above
    /// zero and `0.0` elsewhere. The `"output"` tensor is scaled by 255,
    /// rounded and clamped to `0..=255`.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor creation or the model run fails, if the
    /// output cannot be read as `f32`, or if the output is too small to cover
    /// the tile.
    fn infer_tile_512(
        &mut self,
        tile_img: &image::RgbImage,
        tile_mask: &image::GrayImage,
    ) -> anyhow::Result<image::RgbImage> {
        let (tw, th) = tile_img.dimensions();
        let (w, h) = (tw as usize, th as usize);

        let mut image_data = ndarray::Array4::<f32>::zeros((1, 3, h, w));
        let mut mask_data = ndarray::Array4::<f32>::zeros((1, 1, h, w));
        for y in 0..th {
            for x in 0..tw {
                let (fx, fy) = (x as usize, y as usize);
                let p = tile_img.get_pixel(x, y);
                for c in 0..3 {
                    image_data[[0, c, fy, fx]] = f32::from(p[c]) / 255.0;
                }
                // `mask_data` starts zeroed, so only set pixels need writing.
                if tile_mask.get_pixel(x, y)[0] > 0 {
                    mask_data[[0, 0, fy, fx]] = 1.0;
                }
            }
        }

        let inputs = inputs![
            "image" => TensorRef::from_array_view(image_data.view())?,
            "mask" => TensorRef::from_array_view(mask_data.view())?,
        ];
        let outputs = self.model.run(inputs)?;
        let output = outputs["output"].try_extract_array::<f32>()?;
        let output = output.view();

        // Indexing below would panic on a smaller or differently ranked
        // output; report that as an error instead.
        let shape = output.shape();
        anyhow::ensure!(
            shape.len() == 4 && shape[0] >= 1 && shape[1] >= 3 && shape[2] >= h && shape[3] >= w,
            "unexpected model output shape {:?} for a {}x{} tile",
            shape,
            tw,
            th
        );

        // Round rather than truncate so values such as 0.999 map to 255.
        let to_u8 = |v: f32| (v * 255.0).round().clamp(0.0, 255.0) as u8;

        let mut out_img = image::RgbImage::new(tw, th);
        for y in 0..th {
            for x in 0..tw {
                let (fx, fy) = (x as usize, y as usize);
                let rgb = [
                    to_u8(output[[0, 0, fy, fx]]),
                    to_u8(output[[0, 1, fy, fx]]),
                    to_u8(output[[0, 2, fy, fx]]),
                ];
                out_img.put_pixel(x, y, image::Rgb(rgb));
            }
        }
        Ok(out_img)
    }
}