use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use crossbeam_channel::{Receiver, Sender};
use pixl_gen::{CandleSdxlGenerator, GenRequest, Generator};
use crate::cli::GenerateArgs;
use crate::tui::gallery::Entry;
/// Commands sent from the UI to the generation worker.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenCommand {
    /// Generate `count` images for `prompt`, using `negative` as the negative prompt.
    Generate {
        prompt: String,
        negative: String,
        count: u32,
    },
}
/// Progress and results reported by the generation worker.
pub enum GenEvent {
    /// Model loading has begun.
    Loading,
    /// Download progress for a model file, forwarded from the loader's progress callback.
    Download { file: String, done: u64, total: u64 },
    /// The generator finished loading and is ready for commands.
    Loaded {
        /// Model identifier, as reported by the loader.
        model: String,
        /// Whether the loader reported the weights as already cached.
        cached: bool,
        /// LoRA reported by the loader, if any (presumably name and scale — confirm).
        lora: Option<(String, f32)>,
        /// `true` unless the loader reported no merge (`MergeState::None`).
        merged: bool,
    },
    /// A batch began at global index `start` and covers up to `count` images.
    BatchStarted { start: usize, count: u32 },
    /// The image with global index `index` began generating.
    ImageStarted { index: usize },
    /// Step progress (`step` of `steps`) for the image currently generating.
    Step { step: usize, steps: usize },
    /// Intermediate preview for image `index`.
    Preview { index: usize, image: image::RgbImage },
    /// Image `index` was generated and saved; `entry` describes the result.
    ImageReady { index: usize, entry: Entry },
    /// Generating or saving image `index` failed with `error`.
    ImageFailed { index: usize, error: String },
    /// The current batch ended (completed or cancelled).
    BatchDone,
    /// A fatal worker error; in this module it is sent when the generator fails to load.
    Error(String),
}
/// Handle to the background image-generation worker thread.
///
/// Commands go in through `generate` / `cancel`; progress and results come
/// back on `events`.
pub struct Actor {
// Sending half of the command channel the worker reads from.
cmd: Sender<GenCommand>,
// Events emitted by the worker, in the order it sends them.
pub events: Receiver<GenEvent>,
// Cancellation flag shared with the worker; checked before each image and
// cleared when a new batch starts.
cancel: Arc<AtomicBool>,
// Worker thread handle. It is never joined, so dropping the actor detaches the thread.
_handle: JoinHandle<()>,
}
impl Actor {
    /// Starts the generation worker thread and returns a handle to it.
    ///
    /// Returns immediately; model loading happens on the worker, which reports
    /// its progress through [`GenEvent`]s on `events`. `skip` is shared with the
    /// worker and holds global image indices that must not be generated or reported.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create the thread (as [`std::thread::spawn`] does).
    pub fn spawn(
        args: GenerateArgs,
        w: u32,
        h: u32,
        out_dir: std::path::PathBuf,
        skip: Arc<Mutex<HashSet<usize>>>,
    ) -> Self {
        let (cmd, cmd_rx) = crossbeam_channel::unbounded::<GenCommand>();
        let (evt_tx, events) = crossbeam_channel::unbounded::<GenEvent>();
        let cancel = Arc::new(AtomicBool::new(false));
        let worker_cancel = Arc::clone(&cancel);
        let _handle = std::thread::spawn(move || {
            run(args, w, h, out_dir, cmd_rx, evt_tx, worker_cancel, skip)
        });
        Self {
            cmd,
            events,
            cancel,
            _handle,
        }
    }

    /// Queues a batch of `count` images for `prompt` / `negative`.
    ///
    /// Batches run one after another in the order they were queued.
    pub fn generate(&self, prompt: String, negative: String, count: u32) {
        let command = GenCommand::Generate {
            prompt,
            negative,
            count,
        };
        // A send error means the worker has exited (e.g. model loading failed,
        // which it already reported as `GenEvent::Error`), so the command is dropped.
        let _ = self.cmd.send(command);
    }

    /// Asks the running batch to stop before its next image.
    ///
    /// The image currently generating still completes. Batches queued after
    /// this call are unaffected, because the flag is cleared when each batch starts.
    pub fn cancel(&self) {
        self.cancel.store(true, Ordering::Relaxed);
    }
}
/// Body of the generation worker thread started by `Actor::spawn`.
///
/// Loads the generator once, reporting `Loading`, any `Download` progress, then
/// `Loaded`. If loading fails it reports `Error` and returns. After that it serves
/// [`GenCommand`]s until the command channel disconnects.
///
/// Each batch:
/// - announces its starting global index with `BatchStarted`;
/// - skips any index present in `skip` (checked both before and after generating);
/// - stops early when `cancel` is set;
/// - ends with `BatchDone`.
///
/// The worker returns as soon as a batch or image event can no longer be
/// delivered, because that means every event receiver has been dropped.
// The worker thread owns everything it needs, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn run(
    args: GenerateArgs,
    w: u32,
    h: u32,
    out_dir: std::path::PathBuf,
    cmd_rx: Receiver<GenCommand>,
    evt_tx: Sender<GenEvent>,
    cancel: Arc<AtomicBool>,
    skip: Arc<Mutex<HashSet<usize>>>,
) {
    let _ = evt_tx.send(GenEvent::Loading);
    let (model, loras) = crate::model_and_loras(&args);
    // Forward download progress for model files while the generator loads.
    let prog: pixl_gen::ProgressFn = {
        let evt = evt_tx.clone();
        Box::new(move |p: pixl_gen::DownloadProgress| {
            let _ = evt.send(GenEvent::Download {
                file: p.file,
                done: p.done,
                total: p.total,
            });
        })
    };
    let (mut generator, report) =
        match CandleSdxlGenerator::load(model, w, h, &loras, Some(prog)) {
            Ok(g) => g,
            Err(e) => {
                let _ = evt_tx.send(GenEvent::Error(format!("loading generator: {e}")));
                return;
            }
        };
    // Index of the image currently generating. The preview callback reads it so
    // that previews are attributed to the right slot.
    let cur = Arc::new(AtomicUsize::new(0));
    {
        let evt = evt_tx.clone();
        generator.set_step_callback(Box::new(move |step, steps| {
            let _ = evt.send(GenEvent::Step { step, steps });
        }));
    }
    {
        let evt = evt_tx.clone();
        let cur = cur.clone();
        generator.set_preview_callback(Box::new(move |image| {
            let _ = evt.send(GenEvent::Preview {
                index: cur.load(Ordering::Relaxed),
                image,
            });
        }));
    }
    let _ = evt_tx.send(GenEvent::Loaded {
        model: report.model.to_string(),
        cached: report.weights_cached,
        lora: report.lora.clone(),
        merged: !matches!(report.merge, pixl_gen::MergeState::None),
    });

    // A poisoned lock only means another thread panicked while holding it. The set
    // itself is still usable, so recover it instead of treating every index as
    // not skipped (which would regenerate images the user discarded).
    let is_skipped = |id: usize| {
        skip.lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .contains(&id)
    };

    // Global image index, carried across batches.
    let mut next_index = 0usize;
    while let Ok(command) = cmd_rx.recv() {
        // Irrefutable today. Adding a `GenCommand` variant becomes a compile error
        // here rather than silently ending the worker loop.
        let GenCommand::Generate {
            prompt,
            negative,
            count,
        } = command;
        cancel.store(false, Ordering::Relaxed);
        let started = GenEvent::BatchStarted {
            start: next_index,
            count,
        };
        if evt_tx.send(started).is_err() {
            // Nobody is listening any more; don't burn compute on unseen images.
            return;
        }
        let req = GenRequest {
            prompt: prompt.clone(),
            negative,
            params: crate::gen_params(&args),
        };
        let slug = crate::slugify(&prompt);
        // NOTE(review): on cancel, the unreached tail of this batch's indices is not
        // consumed. The next batch therefore reuses those indices, and any of them
        // still in `skip` would be skipped there. Confirm the UI clears them.
        for _ in 0..count {
            if cancel.load(Ordering::Relaxed) {
                break;
            }
            let id = next_index;
            next_index += 1;
            if is_skipped(id) {
                continue;
            }
            cur.store(id, Ordering::Relaxed);
            if evt_tx.send(GenEvent::ImageStarted { index: id }).is_err() {
                return;
            }
            let outcome = match generator.generate(&req, id) {
                Ok(gi) => {
                    // The slot may have been discarded while it was generating.
                    if is_skipped(id) {
                        continue;
                    }
                    match crate::pixelize_and_save(gi, id, &out_dir, &slug, &args) {
                        Ok(saved) => GenEvent::ImageReady {
                            index: id,
                            entry: Entry {
                                path: saved.path,
                                prompt: prompt.clone(),
                                seed: Some(saved.seed),
                                saved: false,
                            },
                        },
                        Err(e) => GenEvent::ImageFailed {
                            index: id,
                            error: e.to_string(),
                        },
                    }
                }
                Err(e) => GenEvent::ImageFailed {
                    index: id,
                    error: e.to_string(),
                },
            };
            if evt_tx.send(outcome).is_err() {
                return;
            }
        }
        let _ = evt_tx.send(GenEvent::BatchDone);
    }
}