1use super::pipeline::{Flux2SampleOutput, Flux2SampleParams, generate_to_rgb};
19use crate::runner::{Flux2Runner, Flux2RunnerBuilder};
20use anyhow::Result;
21use rlx_runtime::Device;
22use std::collections::HashMap;
23use std::path::PathBuf;
24use std::sync::{Arc, Mutex};
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct Flux2SessionKey {
29 pub weights: PathBuf,
30 pub device: Device,
31 pub config_path: Option<PathBuf>,
32 pub lora_path: Option<PathBuf>,
33 pub lora_scale_bits: u32,
34 pub nvfp4: Option<bool>,
35}
36
37#[derive(Clone)]
39pub struct Flux2Session {
40 inner: Arc<Flux2Runner>,
41}
42
43impl Flux2Session {
44 pub fn open(builder: Flux2RunnerBuilder) -> Result<Self> {
45 Ok(Self {
46 inner: Arc::new(builder.build()?),
47 })
48 }
49
50 pub fn runner(&self) -> &Flux2Runner {
51 &self.inner
52 }
53
54 pub fn sample(&self, params: &Flux2SampleParams<'_>) -> Result<Flux2SampleOutput> {
55 super::pipeline::sample_rectified_flow(&self.inner, params)
56 }
57
58 pub fn generate_rgb(&self, params: &Flux2SampleParams<'_>) -> Result<(Vec<u8>, u32, u32)> {
59 generate_to_rgb(&self.inner, params)
60 }
61}
62
63#[derive(Default)]
65pub struct Flux2SessionCache {
66 sessions: Mutex<HashMap<Flux2SessionKey, Arc<Flux2Runner>>>,
67}
68
69impl Flux2SessionCache {
70 pub fn global() -> &'static Flux2SessionCache {
71 static CACHE: std::sync::OnceLock<Flux2SessionCache> = std::sync::OnceLock::new();
72 CACHE.get_or_init(Flux2SessionCache::default)
73 }
74
75 pub fn get_or_open(&self, builder: Flux2RunnerBuilder) -> Result<Flux2Session> {
76 let key = builder
77 .session_key()
78 .ok_or_else(|| anyhow::anyhow!("session cache requires .weights(...) on builder"))?;
79 let mut guard = self
80 .sessions
81 .lock()
82 .map_err(|e| anyhow::anyhow!("session cache lock poisoned: {e}"))?;
83 if let Some(r) = guard.get(&key) {
84 return Ok(Flux2Session {
85 inner: Arc::clone(r),
86 });
87 }
88 let runner = Arc::new(builder.build()?);
89 guard.insert(key, Arc::clone(&runner));
90 Ok(Flux2Session { inner: runner })
91 }
92
93 pub fn len(&self) -> Result<usize> {
94 Ok(self
95 .sessions
96 .lock()
97 .map_err(|e| anyhow::anyhow!("session cache lock poisoned: {e}"))?
98 .len())
99 }
100
101 pub fn is_empty(&self) -> Result<bool> {
102 Ok(self.len()? == 0)
103 }
104}