Skip to main content

maolan_engine/workers/
hw_worker.rs

1use crate::{
2    hw::config,
3    hw::traits::{HwMidiHub, HwWorkerDriver},
4    message::{HwMidiEvent, Message},
5    mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::{Arc, Condvar, Mutex};
10use std::thread::JoinHandle;
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc::{Receiver, Sender};
13use tracing::error;
14
15pub trait Backend: Send + Sync + 'static {
16    type Driver: HwWorkerDriver + Send + 'static;
17    type MidiHub: HwMidiHub + Send + 'static;
18
19    const LABEL: &'static str;
20    const WORKER_THREAD_NAME: &'static str;
21    const ASSIST_THREAD_NAME: &'static str;
22    const ASSIST_AUTONOMOUS_ENV: &'static str;
23}
24
25#[derive(Debug)]
26pub struct HwWorker<B: Backend> {
27    driver: Arc<UnsafeMutex<B::Driver>>,
28    midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
29    rx: Receiver<Message>,
30    tx: Sender<Message>,
31    cycle_frames: u32,
32    pending_midi_out_events: Vec<HwMidiEvent>,
33    midi_in_events: Vec<HwMidiEvent>,
34    pending_midi_out_sorted: bool,
35}
36
/// Handshake state shared, behind a `Mutex` paired with a `Condvar`, between
/// the worker and its assist thread.
#[derive(Default, Debug)]
struct AssistState {
    /// Raised once to ask the assist thread to exit.
    shutdown: bool,
    /// Sequence number of the latest cycle request issued by the worker.
    request_seq: u64,
    /// Highest request sequence number the assist thread has completed.
    done_seq: u64,
    /// Error from the latest assist cycle or step; taken by the worker.
    last_error: Option<String>,
}
44
45#[cfg(unix)]
46const RT_POLICY: i32 = libc::SCHED_FIFO;
47const RT_PRIORITY_WORKER: i32 = 18;
48const RT_PRIORITY_ASSIST: i32 = 12;
49const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
50
/// Per-interval counters collected by the assist thread when profiling is on.
#[derive(Debug)]
struct AssistProfiler {
    /// Instant after which the next summary line is due.
    report_at: Instant,
    /// Worker-requested cycles run in the current interval.
    cycle_count: u64,
    /// Requested cycles that returned an error.
    cycle_err_count: u64,
    /// Total nanoseconds spent in requested cycles.
    cycle_time_ns: u128,
    /// Autonomous assist steps attempted.
    step_count: u64,
    /// Autonomous steps that reported doing work.
    step_work_count: u64,
    /// Autonomous steps that returned an error.
    step_err_count: u64,
    /// Total nanoseconds spent in autonomous steps.
    step_time_ns: u128,
    /// Condition-variable waits performed.
    wait_count: u64,
    /// Total nanoseconds spent waiting.
    wait_time_ns: u128,
}
64
65impl AssistProfiler {
66    fn new() -> Self {
67        Self {
68            report_at: Instant::now() + PROFILE_INTERVAL,
69            cycle_count: 0,
70            cycle_err_count: 0,
71            cycle_time_ns: 0,
72            step_count: 0,
73            step_work_count: 0,
74            step_err_count: 0,
75            step_time_ns: 0,
76            wait_count: 0,
77            wait_time_ns: 0,
78        }
79    }
80
81    fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
82        let now = Instant::now();
83        if now < self.report_at {
84            return;
85        }
86        let cycle_avg_us = if self.cycle_count > 0 {
87            (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
88        } else {
89            0.0
90        };
91        let step_avg_us = if self.step_count > 0 {
92            (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
93        } else {
94            0.0
95        };
96        let wait_avg_us = if self.wait_count > 0 {
97            (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
98        } else {
99            0.0
100        };
101        let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
102            sample_rate as f64 / cycle_samples as f64
103        } else {
104            0.0
105        };
106        error!(
107            "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
108            label,
109            expected_cycles_per_sec,
110            self.cycle_count,
111            self.cycle_err_count,
112            cycle_avg_us,
113            self.step_count,
114            self.step_work_count,
115            self.step_err_count,
116            step_avg_us,
117            self.wait_count,
118            wait_avg_us
119        );
120        self.report_at = now + PROFILE_INTERVAL;
121        self.cycle_count = 0;
122        self.cycle_err_count = 0;
123        self.cycle_time_ns = 0;
124        self.step_count = 0;
125        self.step_work_count = 0;
126        self.step_err_count = 0;
127        self.step_time_ns = 0;
128        self.wait_count = 0;
129        self.wait_time_ns = 0;
130    }
131}
132
133impl<B: Backend> HwWorker<B> {
134    fn profile_enabled() -> bool {
135        config::env_flag(config::HW_PROFILE_ENV)
136    }
137
138    fn assist_autonomous_enabled() -> bool {
139        config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
140    }
141
142    fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
143        #[cfg(unix)]
144        {
145            let thread = unsafe { libc::pthread_self() };
146            #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
147            let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
148            #[cfg(target_os = "linux")]
149            unsafe {
150                let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
151            }
152            #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
153            unsafe {
154                libc::pthread_set_name_np(thread, c_name.as_ptr());
155            }
156
157            let param = unsafe {
158                let mut p = std::mem::zeroed::<libc::sched_param>();
159                p.sched_priority = priority;
160                p
161            };
162            let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, &param) };
163            if rc != 0 {
164                return Err(format!(
165                    "pthread_setschedparam({}, prio {}) failed with errno {}",
166                    name, priority, rc
167                ));
168            }
169
170            let mut actual_policy = 0_i32;
171            let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
172            let rc = unsafe {
173                libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
174            };
175            if rc != 0 {
176                return Err(format!(
177                    "pthread_getschedparam({}) failed with errno {}",
178                    name, rc
179                ));
180            }
181            if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
182                return Err(format!(
183                    "realtime verification failed for {}: policy {}, prio {}",
184                    name, actual_policy, actual_param.sched_priority
185                ));
186            }
187            Ok(())
188        }
189        #[cfg(not(unix))]
190        {
191            let _ = name;
192            let _ = priority;
193            Err("Realtime thread priority is not supported on this platform".to_string())
194        }
195    }
196
197    fn lock_memory_pages() -> Result<(), String> {
198        #[cfg(unix)]
199        {
200            let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
201            if rc == 0 {
202                Ok(())
203            } else {
204                Err(format!(
205                    "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
206                    std::io::Error::last_os_error()
207                ))
208            }
209        }
210        #[cfg(not(unix))]
211        {
212            Err("mlockall is not supported on this platform".to_string())
213        }
214    }
215
216    pub fn new(
217        driver: Arc<UnsafeMutex<B::Driver>>,
218        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
219        rx: Receiver<Message>,
220        tx: Sender<Message>,
221    ) -> Self {
222        let cycle_frames = {
223            let d = driver.lock();
224            d.cycle_samples() as u32
225        };
226        Self {
227            driver,
228            midi_hub,
229            rx,
230            tx,
231            cycle_frames,
232            pending_midi_out_events: vec![],
233            midi_in_events: Vec::with_capacity(64),
234            pending_midi_out_sorted: true,
235        }
236    }
237
238    pub async fn work(mut self) {
239        if let Err(e) = Self::lock_memory_pages() {
240            error!("{} worker memory lock not enabled: {}", B::LABEL, e);
241        }
242        if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
243            error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
244        }
245        #[cfg(target_os = "macos")]
246        unsafe {
247            libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
248        }
249        let assist_state = Arc::new((Mutex::new(AssistState::default()), Condvar::new()));
250        let assist_handle = Self::start_assist_thread(self.driver.clone(), assist_state.clone());
251        loop {
252            match self.rx.recv().await {
253                Some(msg) => match msg {
254                    Message::Request(crate::message::Action::Quit) => {
255                        if !self.pending_midi_out_events.is_empty() {
256                            if !self.pending_midi_out_sorted {
257                                self.pending_midi_out_events.sort_by(|a, b| {
258                                    a.event
259                                        .frame
260                                        .cmp(&b.event.frame)
261                                        .then_with(|| a.device.cmp(&b.device))
262                                });
263                                self.pending_midi_out_sorted = true;
264                            }
265                            let midi_hub = self.midi_hub.lock();
266                            midi_hub.write_events(&self.pending_midi_out_events);
267                            self.pending_midi_out_events.clear();
268                        }
269                        Self::stop_assist_thread(&assist_state, assist_handle);
270                        return;
271                    }
272                    Message::TracksFinished => {
273                        {
274                            let midi_hub = self.midi_hub.lock();
275                            midi_hub.read_events_into(&mut self.midi_in_events);
276                        }
277                        spread_hw_event_frames(&mut self.midi_in_events, self.cycle_frames);
278                        if !self.midi_in_events.is_empty() {
279                            let cap = self.midi_in_events.capacity();
280                            let out = std::mem::replace(
281                                &mut self.midi_in_events,
282                                Vec::with_capacity(cap.max(64)),
283                            );
284                            if let Err(e) = self.tx.send(Message::HWMidiEvents(out)).await {
285                                error!(
286                                    "{} worker failed to send HWMidiEvents to engine: {}",
287                                    B::LABEL,
288                                    e
289                                );
290                            }
291                        }
292                        {
293                            if !self.pending_midi_out_events.is_empty() {
294                                if !self.pending_midi_out_sorted {
295                                    self.pending_midi_out_events.sort_by(|a, b| {
296                                        a.event
297                                            .frame
298                                            .cmp(&b.event.frame)
299                                            .then_with(|| a.device.cmp(&b.device))
300                                    });
301                                    self.pending_midi_out_sorted = true;
302                                }
303                                let midi_hub = self.midi_hub.lock();
304                                midi_hub.write_events(&self.pending_midi_out_events);
305                                self.pending_midi_out_events.clear();
306                            }
307                        }
308                        if let Err(e) = Self::run_assist_cycle(&assist_state) {
309                            error!("{} assist cycle error: {}", B::LABEL, e);
310                        }
311                        if let Err(e) = self.tx.send(Message::HWFinished).await {
312                            error!(
313                                "{} worker failed to send HWFinished to engine: {}",
314                                B::LABEL,
315                                e
316                            );
317                        }
318                    }
319                    Message::HWMidiOutEvents(mut events) => {
320                        self.pending_midi_out_events.append(&mut events);
321                        self.pending_midi_out_sorted = false;
322                    }
323                    Message::ClearHWMidiOutEvents => {
324                        self.pending_midi_out_events.clear();
325                        self.pending_midi_out_sorted = true;
326                    }
327                    _ => {}
328                },
329                None => {
330                    Self::stop_assist_thread(&assist_state, assist_handle);
331                    return;
332                }
333            }
334        }
335    }
336
337    fn start_assist_thread(
338        driver: Arc<UnsafeMutex<B::Driver>>,
339        assist_state: Arc<(Mutex<AssistState>, Condvar)>,
340    ) -> JoinHandle<()> {
341        let profile = Self::profile_enabled();
342        let autonomous = Self::assist_autonomous_enabled();
343        std::thread::spawn(move || {
344            if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
345                error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
346            }
347            #[cfg(target_os = "macos")]
348            unsafe {
349                libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
350            }
351            let mut profiler = if profile {
352                let (cycle_samples, sample_rate) = {
353                    let d = driver.lock();
354                    (d.cycle_samples(), d.sample_rate())
355                };
356                error!(
357                    "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
358                    B::LABEL,
359                    cycle_samples,
360                    sample_rate,
361                    if cycle_samples > 0 {
362                        sample_rate as f64 / cycle_samples as f64
363                    } else {
364                        0.0
365                    }
366                );
367                Some(AssistProfiler::new())
368            } else {
369                None
370            };
371            let (lock, cvar) = &*assist_state;
372            loop {
373                let (shutdown, has_request, target) = {
374                    let st = lock.lock().expect("assist mutex poisoned");
375                    (st.shutdown, st.request_seq > st.done_seq, st.request_seq)
376                };
377                if shutdown {
378                    break;
379                }
380                if has_request {
381                    let started = Instant::now();
382                    let run_error = {
383                        let d = driver.lock();
384                        d.run_cycle_for_worker().err().map(|e| e.to_string())
385                    };
386                    if let Some(p) = profiler.as_mut() {
387                        p.cycle_count += 1;
388                        if run_error.is_some() {
389                            p.cycle_err_count += 1;
390                        }
391                        p.cycle_time_ns += started.elapsed().as_nanos();
392                        let (cycle_samples, sample_rate) = {
393                            let d = driver.lock();
394                            (d.cycle_samples(), d.sample_rate())
395                        };
396                        p.maybe_report(cycle_samples, sample_rate, B::LABEL);
397                    }
398                    let mut st = lock.lock().expect("assist mutex poisoned");
399                    st.done_seq = st.done_seq.max(target);
400                    st.last_error = run_error;
401                    cvar.notify_all();
402                    continue;
403                }
404
405                if !autonomous {
406                    let st = lock.lock().expect("assist mutex poisoned");
407                    if st.shutdown {
408                        break;
409                    }
410                    let wait_started = Instant::now();
411                    let _guard = cvar.wait(st).expect("assist condvar failed");
412                    if let Some(p) = profiler.as_mut() {
413                        p.wait_count += 1;
414                        p.wait_time_ns += wait_started.elapsed().as_nanos();
415                    }
416                    continue;
417                }
418
419                let started = Instant::now();
420                let did_work = {
421                    let d = driver.lock();
422                    match d.run_assist_step_for_worker() {
423                        Ok(v) => v,
424                        Err(e) => {
425                            if let Some(p) = profiler.as_mut() {
426                                p.step_err_count += 1;
427                            }
428                            let mut st = lock.lock().expect("assist mutex poisoned");
429                            st.last_error = Some(e.to_string());
430                            cvar.notify_all();
431                            false
432                        }
433                    }
434                };
435                if let Some(p) = profiler.as_mut() {
436                    p.step_count += 1;
437                    if did_work {
438                        p.step_work_count += 1;
439                    }
440                    p.step_time_ns += started.elapsed().as_nanos();
441                    let (cycle_samples, sample_rate) = {
442                        let d = driver.lock();
443                        (d.cycle_samples(), d.sample_rate())
444                    };
445                    p.maybe_report(cycle_samples, sample_rate, B::LABEL);
446                }
447                if !did_work {
448                    let st = lock.lock().expect("assist mutex poisoned");
449                    if st.shutdown {
450                        break;
451                    }
452                    let wait_started = Instant::now();
453                    let _guard = cvar.wait(st).expect("assist condvar failed");
454                    if let Some(p) = profiler.as_mut() {
455                        p.wait_count += 1;
456                        p.wait_time_ns += wait_started.elapsed().as_nanos();
457                    }
458                }
459            }
460        })
461    }
462
463    fn run_assist_cycle(assist_state: &Arc<(Mutex<AssistState>, Condvar)>) -> Result<(), String> {
464        let (lock, cvar) = &**assist_state;
465        let mut st = lock
466            .lock()
467            .map_err(|_| "assist mutex poisoned".to_string())?;
468        st.request_seq = st.request_seq.saturating_add(1);
469        let target = st.request_seq;
470        cvar.notify_one();
471        while st.done_seq < target && !st.shutdown {
472            st = cvar
473                .wait(st)
474                .map_err(|_| "assist condvar wait failed".to_string())?;
475        }
476        if let Some(err) = st.last_error.take() {
477            return Err(err);
478        }
479        Ok(())
480    }
481
482    fn stop_assist_thread(
483        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
484        assist_handle: JoinHandle<()>,
485    ) {
486        let (lock, cvar) = &**assist_state;
487        if let Ok(mut st) = lock.lock() {
488            st.shutdown = true;
489            cvar.notify_all();
490        }
491        let _ = assist_handle.join();
492    }
493}
494
495fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
496    if events.len() <= 1 || frames <= 1 {
497        return;
498    }
499    let n = events.len() as u32;
500    for (idx, event) in events.iter_mut().enumerate() {
501        let pos = idx as u32;
502        event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
503    }
504}