use crate::*;
/// A fully prepared synthesis job, produced by [`SessionBuilder::build`].
///
/// A session is single-use: [`Session::run`] consumes it.
pub struct Session {
    /// Parameters carried over from the builder.
    params: Parameters,
    /// Generator that produces the output image.
    generator: Generator,
    /// One pyramid per example, in the order the examples were added to the builder.
    examples: Vec<ImagePyramid>,
    /// Sampling method for each example; parallel to `examples`.
    sampling_methods: Vec<SamplingMethod>,
    /// Target and example guide pyramids; `Some` only when a target guide was loaded.
    guides: Option<GuidesPyramidStruct>,
}
impl Session {
    /// Returns a new [`SessionBuilder`] with default parameters.
    pub fn builder<'a>() -> SessionBuilder<'a> {
        SessionBuilder::default()
    }

    /// Runs the synthesis, consuming the session, and returns the generated image.
    ///
    /// If `random_init` was set on the builder, `resolve_random_batch` is first
    /// called with that count, the session seed, and the last pyramid level of every
    /// example. The main resolve then runs with the session's parameters, examples,
    /// guides and sampling methods. `progress`, if given, is forwarded to the generator.
    pub fn run(mut self, progress: Option<Box<dyn GeneratorProgress>>) -> GeneratedImage {
        if let Some(count) = self.params.random_resolve {
            // Index each example's *own* last level rather than reusing the depth of
            // the first example's pyramid for all of them. The old approach panics or
            // picks a non-final level if pyramid depths ever differ.
            // NOTE(review): the last level is presumably the coarsest/blurriest one;
            // confirm against `ImagePyramid`.
            let imgs: Vec<_> = self
                .examples
                .iter()
                .map(|ex| ImageBuffer::from(&ex.pyramid[ex.pyramid.len() - 1]))
                .collect();
            self.generator
                .resolve_random_batch(count as usize, &imgs, self.params.seed);
        }

        self.generator.resolve(
            &self.params.to_generator_params(),
            &self.examples,
            progress,
            &self.guides,
            &self.sampling_methods,
        );

        GeneratedImage {
            inner: self.generator,
        }
    }
}
/// Incrementally configures a [`Session`].
///
/// Obtain one with [`Session::builder`] or [`SessionBuilder::new`], chain the
/// setter methods, and finish with [`SessionBuilder::build`].
#[derive(Default)]
pub struct SessionBuilder<'a> {
    /// Synthesis parameters; starts from `Parameters::default()`.
    params: Parameters,
    /// Examples in the order they were added.
    examples: Vec<Example<'a>>,
    /// Optional target guide image, set by `load_target_guide`.
    target_guide: Option<ImageSource<'a>>,
    /// Optional inpaint configuration; each `inpaint_example*` call replaces it.
    inpaint_mask: Option<InpaintMask<'a>>,
}
impl<'a> SessionBuilder<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn add_example<E: Into<Example<'a>>>(mut self, example: E) -> Self {
self.examples.push(example.into());
self
}
pub fn add_examples<E: Into<Example<'a>>, I: IntoIterator<Item = E>>(
mut self,
examples: I,
) -> Self {
self.examples.extend(examples.into_iter().map(|e| e.into()));
self
}
pub fn inpaint_example<I: Into<ImageSource<'a>>, E: Into<Example<'a>>>(
mut self,
inpaint_mask: I,
example: E,
size: Dims,
) -> Self {
self.inpaint_mask = Some(InpaintMask {
src: MaskOrImg::ImageSource(inpaint_mask.into()),
example_index: self.examples.len(),
dims: size,
});
self.examples.push(example.into());
self
}
pub fn inpaint_example_channel<E: Into<Example<'a>>>(
mut self,
mask: utils::ChannelMask,
example: E,
size: Dims,
) -> Self {
self.inpaint_mask = Some(InpaintMask {
src: MaskOrImg::Mask(mask),
example_index: self.examples.len(),
dims: size,
});
self.examples.push(example.into());
self
}
pub fn load_target_guide<I: Into<ImageSource<'a>>>(mut self, guide: I) -> Self {
self.target_guide = Some(guide.into());
self
}
pub fn resize_input(mut self, dims: Dims) -> Self {
self.params.resize_input = Some(dims);
self
}
pub fn seed(mut self, value: u64) -> Self {
self.params.seed = value;
self
}
pub fn tiling_mode(mut self, is_tiling: bool) -> Self {
self.params.tiling_mode = is_tiling;
self
}
pub fn nearest_neighbors(mut self, count: u32) -> Self {
self.params.nearest_neighbors = count;
self
}
pub fn random_sample_locations(mut self, count: u64) -> Self {
self.params.random_sample_locations = count;
self
}
pub fn random_init(mut self, count: u64) -> Self {
self.params.random_resolve = Some(count);
self
}
pub fn cauchy_dispersion(mut self, value: f32) -> Self {
self.params.cauchy_dispersion = value;
self
}
pub fn guide_alpha(mut self, value: f32) -> Self {
self.params.guide_alpha = value;
self
}
pub fn backtrack_percent(mut self, value: f32) -> Self {
self.params.backtrack_percent = value;
self
}
pub fn backtrack_stages(mut self, stages: u32) -> Self {
self.params.backtrack_stages = stages;
self
}
pub fn output_size(mut self, dims: Dims) -> Self {
self.params.output_size = dims;
self
}
pub fn max_thread_count(mut self, count: usize) -> Self {
self.params.max_thread_count = Some(count);
self
}
pub fn build(mut self) -> Result<Session, Error> {
self.check_parameters_validity()?;
self.check_images_validity()?;
struct InpaintExample {
inpaint_mask: image::RgbaImage,
color_map: image::RgbaImage,
example_index: usize,
}
let (inpaint, out_size, in_size) = match self.inpaint_mask {
Some(inpaint_mask) => {
let dims = inpaint_mask.dims;
let inpaint_img = match inpaint_mask.src {
MaskOrImg::ImageSource(img) => load_image(img, Some(dims))?,
MaskOrImg::Mask(mask) => {
let example_img = &mut self.examples[inpaint_mask.example_index].img;
let dynamic_img = utils::load_dynamic_image(example_img.clone())?;
let inpaint_src = ImageSource::Image(dynamic_img.clone());
*example_img = ImageSource::Image(dynamic_img);
let inpaint_mask = load_image(inpaint_src, Some(dims))?;
utils::apply_mask(inpaint_mask, mask)
}
};
let color_map = load_image(
self.examples[inpaint_mask.example_index].img.clone(),
Some(dims),
)?;
(
Some(InpaintExample {
inpaint_mask: inpaint_img,
color_map,
example_index: inpaint_mask.example_index,
}),
dims,
Some(dims),
)
}
None => (None, self.params.output_size, self.params.resize_input),
};
let target_guide = match self.target_guide {
Some(tg) => {
let tg_img = load_image(tg, Some(out_size))?;
let num_guides = self.examples.iter().filter(|ex| ex.guide.is_some()).count();
let tg_img = if num_guides == 0 {
transform_to_guide_map(tg_img, None, 2.0)
} else {
tg_img
};
Some(ImagePyramid::new(
tg_img,
Some(self.params.backtrack_stages as u32),
))
}
None => None,
};
let example_len = self.examples.len();
let mut examples = Vec::with_capacity(example_len);
let mut guides = if target_guide.is_some() {
Vec::with_capacity(example_len)
} else {
Vec::new()
};
let mut methods = Vec::with_capacity(example_len);
for example in self.examples {
let resolved = example.resolve(self.params.backtrack_stages, in_size, &target_guide)?;
examples.push(resolved.image);
if let Some(guide) = resolved.guide {
guides.push(guide);
}
methods.push(resolved.method);
}
let generator = match inpaint {
None => Generator::new(out_size),
Some(inpaint) => Generator::new_from_inpaint(
out_size,
inpaint.inpaint_mask,
inpaint.color_map,
inpaint.example_index,
),
};
let session = Session {
examples,
guides: target_guide.map(|tg| GuidesPyramidStruct {
target_guide: tg,
example_guides: guides,
}),
sampling_methods: methods,
params: self.params,
generator,
};
Ok(session)
}
fn check_parameters_validity(&self) -> Result<(), Error> {
if self.params.cauchy_dispersion < 0.0 || self.params.cauchy_dispersion > 1.0 {
return Err(Error::InvalidRange(errors::InvalidRange {
min: 0.0,
max: 1.0,
value: self.params.cauchy_dispersion,
name: "cauchy-dispersion",
}));
}
if self.params.backtrack_percent < 0.0 || self.params.backtrack_percent > 1.0 {
return Err(Error::InvalidRange(errors::InvalidRange {
min: 0.0,
max: 1.0,
value: self.params.backtrack_percent,
name: "backtrack-percent",
}));
}
if self.params.guide_alpha < 0.0 || self.params.guide_alpha > 1.0 {
return Err(Error::InvalidRange(errors::InvalidRange {
min: 0.0,
max: 1.0,
value: self.params.guide_alpha,
name: "guide-alpha",
}));
}
if let Some(max_count) = self.params.max_thread_count {
if max_count == 0 {
return Err(Error::InvalidRange(errors::InvalidRange {
min: 1.0,
max: 1024.0,
value: max_count as f32,
name: "max-thread-count",
}));
}
}
if self.params.random_sample_locations == 0 {
return Err(Error::InvalidRange(errors::InvalidRange {
min: 1.0,
max: 1024.0,
value: self.params.random_sample_locations as f32,
name: "m-rand",
}));
}
Ok(())
}
fn check_images_validity(&self) -> Result<(), Error> {
let input_count = self
.examples
.iter()
.filter(|ex| !ex.sample_method.is_ignore())
.count();
if input_count == 0 {
return Err(Error::NoExamples);
}
let num_guides = self.examples.iter().filter(|ex| ex.guide.is_some()).count();
if num_guides != 0 && self.examples.len() != num_guides {
return Err(Error::ExampleGuideMismatch(
self.examples.len() as u32,
num_guides as u32,
));
}
Ok(())
}
}
/// A progress counter: `current` units done out of `total`.
///
/// It holds two `usize`s, so it is `Copy`. It derives `Debug` so callers can log
/// progress, and `PartialEq` so they can compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProgressStat {
    /// Units completed so far.
    pub current: usize,
    /// Total units expected.
    pub total: usize,
}
/// A progress snapshot passed to [`GeneratorProgress::update`].
pub struct ProgressUpdate<'a> {
    /// Overall progress counter; presumably spans the whole run — confirm in the generator.
    pub total: ProgressStat,
    /// Progress counter for the current stage.
    pub stage: ProgressStat,
    /// The image as it stands when the update is sent (borrowed for the call).
    pub image: &'a image::RgbaImage,
}
/// Receives progress notifications from the generator.
///
/// Any `FnMut(ProgressUpdate<'_>) + Send` closure implements this trait through
/// a blanket impl.
pub trait GeneratorProgress {
    /// Called with the latest progress snapshot.
    fn update(&mut self, snapshot: ProgressUpdate<'_>);
}
/// Lets plain `Send` closures act as progress callbacks by forwarding each update
/// to the closure.
impl<F: FnMut(ProgressUpdate<'_>) + Send> GeneratorProgress for F {
    fn update(&mut self, snapshot: ProgressUpdate<'_>) {
        (*self)(snapshot);
    }
}