Skip to main content

maolan_engine/workers/
worker.rs

1use crate::message::{
2    Action, Message, OfflineAutomationLane, OfflineAutomationPoint, OfflineAutomationTarget,
3    OfflineBounceWork,
4};
5#[cfg(unix)]
6use nix::libc;
7use tokio::sync::mpsc::{Receiver, Sender};
8use tracing::error;
9use wavers::write as write_wav;
10
11#[derive(Debug)]
12pub struct Worker {
13    id: usize,
14    rx: Receiver<Message>,
15    tx: Sender<Message>,
16}
17
18impl Worker {
19    fn automation_lane_value_at(points: &[OfflineAutomationPoint], sample: usize) -> Option<f32> {
20        if points.is_empty() {
21            return None;
22        }
23        if sample <= points[0].sample {
24            return Some(points[0].value.clamp(0.0, 1.0));
25        }
26        if sample >= points[points.len().saturating_sub(1)].sample {
27            return Some(points[points.len().saturating_sub(1)].value.clamp(0.0, 1.0));
28        }
29        for segment in points.windows(2) {
30            let left = &segment[0];
31            let right = &segment[1];
32            if sample < left.sample || sample > right.sample {
33                continue;
34            }
35            let span = right.sample.saturating_sub(left.sample).max(1) as f32;
36            let t = (sample.saturating_sub(left.sample) as f32 / span).clamp(0.0, 1.0);
37            return Some((left.value + (right.value - left.value) * t).clamp(0.0, 1.0));
38        }
39        None
40    }
41
42    fn apply_freeze_automation_at_sample(
43        track: &mut crate::track::Track,
44        sample: usize,
45        lanes: &[OfflineAutomationLane],
46    ) {
47        for lane in lanes {
48            if matches!(
49                lane.target,
50                OfflineAutomationTarget::Volume | OfflineAutomationTarget::Balance
51            ) {
52                continue;
53            }
54            let Some(value) = Self::automation_lane_value_at(&lane.points, sample) else {
55                continue;
56            };
57            match lane.target {
58                OfflineAutomationTarget::Volume | OfflineAutomationTarget::Balance => {}
59                OfflineAutomationTarget::Mute => {
60                    track.set_muted(value >= 0.5);
61                }
62                #[cfg(all(unix, not(target_os = "macos")))]
63                OfflineAutomationTarget::Lv2Parameter {
64                    instance_id,
65                    index,
66                    min,
67                    max,
68                } => {
69                    let lo = min.min(max);
70                    let hi = max.max(min);
71                    let param_value = (lo + value * (hi - lo)).clamp(lo, hi);
72                    let _ = track.set_lv2_control_value(instance_id, index, param_value);
73                }
74                OfflineAutomationTarget::Vst3Parameter {
75                    instance_id,
76                    param_id,
77                } => {
78                    let _ = track.set_vst3_parameter(instance_id, param_id, value.clamp(0.0, 1.0));
79                }
80                OfflineAutomationTarget::ClapParameter {
81                    instance_id,
82                    param_id,
83                    min,
84                    max,
85                } => {
86                    let lo = min.min(max);
87                    let hi = max.max(min);
88                    let param_value = (lo + value as f64 * (hi - lo)).clamp(lo, hi);
89                    let _ = track.set_clap_parameter_at(instance_id, param_id, param_value, 0);
90                }
91            }
92        }
93    }
94
95    fn prepare_track_for_freeze_render(track: &mut crate::track::Track) -> (f32, f32) {
96        let original_level = track.level();
97        let original_balance = track.balance;
98        track.set_level(0.0);
99        track.set_balance(0.0);
100        (original_level, original_balance)
101    }
102
103    fn restore_track_after_freeze_render(
104        track: &mut crate::track::Track,
105        original_level: f32,
106        original_balance: f32,
107    ) {
108        track.set_level(original_level);
109        track.set_balance(original_balance);
110    }
111
112    async fn process_offline_bounce(&self, job: OfflineBounceWork) {
113        let track_handle = job.state.lock().tracks.get(&job.track_name).cloned();
114        let Some(target_track) = track_handle else {
115            let _ = self
116                .tx
117                .send(Message::OfflineBounceFinished {
118                    result: Err(format!("Track not found: {}", job.track_name)),
119                })
120                .await;
121            return;
122        };
123        let (channels, block_size, sample_rate) = {
124            let t = target_track.lock();
125            let block_size = t
126                .audio
127                .outs
128                .first()
129                .map(|io| io.buffer.lock().len())
130                .or_else(|| t.audio.ins.first().map(|io| io.buffer.lock().len()))
131                .unwrap_or(0)
132                .max(1);
133            (
134                t.audio.outs.len().max(1),
135                block_size,
136                t.sample_rate.round().max(1.0) as i32,
137            )
138        };
139        let (original_level, original_balance) = {
140            let mut t = target_track.lock();
141            Self::prepare_track_for_freeze_render(&mut t)
142        };
143
144        let mut rendered = vec![0.0_f32; job.length_samples.saturating_mul(channels)];
145        let mut cursor = 0usize;
146        while cursor < job.length_samples {
147            if job.cancel.load(std::sync::atomic::Ordering::Relaxed) {
148                {
149                    let mut t = target_track.lock();
150                    Self::restore_track_after_freeze_render(
151                        &mut t,
152                        original_level,
153                        original_balance,
154                    );
155                }
156                let _ = self
157                    .tx
158                    .send(Message::OfflineBounceFinished {
159                        result: Ok(Action::TrackOfflineBounceCanceled {
160                            track_name: job.track_name.clone(),
161                        }),
162                    })
163                    .await;
164                let _ = self.tx.send(Message::Ready(self.id)).await;
165                return;
166            }
167
168            let step = (job.length_samples - cursor).min(block_size);
169            let tracks: Vec<_> = job.state.lock().tracks.values().cloned().collect();
170            for handle in &tracks {
171                let t = handle.lock();
172                t.audio.finished = false;
173                t.audio.processing = false;
174                t.set_transport_sample(job.start_sample.saturating_add(cursor));
175                t.set_loop_config(false, None);
176                t.set_transport_timing(job.tempo_bpm, job.tsig_num, job.tsig_denom);
177                t.set_clip_playback_enabled(true);
178                t.set_record_tap_enabled(false);
179            }
180
181            let mut remaining = tracks.len();
182            while remaining > 0 {
183                let mut progressed = false;
184                for handle in &tracks {
185                    let t = handle.lock();
186                    if t.audio.finished || t.audio.processing {
187                        continue;
188                    }
189                    if t.audio.ready() {
190                        if t.name == job.track_name {
191                            Self::apply_freeze_automation_at_sample(
192                                t,
193                                job.start_sample.saturating_add(cursor),
194                                &job.automation_lanes,
195                            );
196                        }
197                        t.audio.processing = true;
198                        t.process();
199                        t.audio.processing = false;
200                        progressed = true;
201                        remaining = remaining.saturating_sub(1);
202                    }
203                }
204                if !progressed {
205                    for handle in &tracks {
206                        let t = handle.lock();
207                        if t.audio.finished {
208                            continue;
209                        }
210                        if t.name == job.track_name {
211                            Self::apply_freeze_automation_at_sample(
212                                t,
213                                job.start_sample.saturating_add(cursor),
214                                &job.automation_lanes,
215                            );
216                        }
217                        t.audio.processing = true;
218                        t.process();
219                        t.audio.processing = false;
220                        remaining = remaining.saturating_sub(1);
221                    }
222                }
223            }
224
225            {
226                let t = target_track.lock();
227                for ch in 0..channels {
228                    let out = t.audio.outs[ch].buffer.lock();
229                    let copy_len = step.min(out.len());
230                    for i in 0..copy_len {
231                        let dst = (cursor + i) * channels + ch;
232                        rendered[dst] = out[i];
233                    }
234                }
235            }
236
237            cursor = cursor.saturating_add(step);
238            let _ = self
239                .tx
240                .send(Message::OfflineBounceFinished {
241                    result: Ok(Action::TrackOfflineBounceProgress {
242                        track_name: job.track_name.clone(),
243                        progress: (cursor as f32 / job.length_samples as f32).clamp(0.0, 1.0),
244                        operation: Some("Rendering freeze".to_string()),
245                    }),
246                })
247                .await;
248        }
249
250        if let Err(e) =
251            write_wav::<f32, _>(&job.output_path, &rendered, sample_rate, channels as u16)
252        {
253            {
254                let mut t = target_track.lock();
255                Self::restore_track_after_freeze_render(&mut t, original_level, original_balance);
256            }
257            let _ = self
258                .tx
259                .send(Message::OfflineBounceFinished {
260                    result: Err(format!(
261                        "Failed to write offline bounce '{}': {e}",
262                        job.output_path
263                    )),
264                })
265                .await;
266            let _ = self.tx.send(Message::Ready(self.id)).await;
267            return;
268        }
269
270        {
271            let mut t = target_track.lock();
272            Self::restore_track_after_freeze_render(&mut t, original_level, original_balance);
273        }
274
275        let _ = self
276            .tx
277            .send(Message::OfflineBounceFinished {
278                result: Ok(Action::TrackOfflineBounce {
279                    track_name: job.track_name,
280                    output_path: job.output_path,
281                    start_sample: job.start_sample,
282                    length_samples: job.length_samples,
283                    automation_lanes: vec![],
284                }),
285            })
286            .await;
287        let _ = self.tx.send(Message::Ready(self.id)).await;
288    }
289
290    #[cfg(unix)]
291    fn try_enable_realtime() -> Result<(), String> {
292        let thread = unsafe { libc::pthread_self() };
293        let policy = libc::SCHED_FIFO;
294        let param = unsafe {
295            let mut p = std::mem::zeroed::<libc::sched_param>();
296            p.sched_priority = 10;
297            p
298        };
299        let rc = unsafe { libc::pthread_setschedparam(thread, policy, &param) };
300        if rc == 0 {
301            Ok(())
302        } else {
303            Err(format!("pthread_setschedparam failed with errno {}", rc))
304        }
305    }
306
307    #[cfg(not(unix))]
308    fn try_enable_realtime() -> Result<(), String> {
309        Err("Realtime thread priority is not supported on this platform".to_string())
310    }
311
312    pub async fn new(id: usize, rx: Receiver<Message>, tx: Sender<Message>) -> Worker {
313        let worker = Worker { id, rx, tx };
314        worker.send(Message::Ready(id)).await;
315        worker
316    }
317
318    pub async fn send(&self, message: Message) {
319        self.tx
320            .send(message)
321            .await
322            .expect("Failed to send message from worker");
323    }
324
325    pub async fn work(&mut self) {
326        if let Err(e) = Self::try_enable_realtime() {
327            error!("Worker {} realtime priority not enabled: {}", self.id, e);
328        }
329        while let Some(message) = self.rx.recv().await {
330            match message {
331                Message::Request(Action::Quit) => {
332                    return;
333                }
334                Message::ProcessTrack(t) => {
335                    let (track_name, output_linear, process_epoch) = {
336                        let track = t.lock();
337                        let process_epoch = track.process_epoch;
338                        track.process();
339                        track.audio.processing = false;
340                        (
341                            track.name.clone(),
342                            track.output_meter_linear(),
343                            process_epoch,
344                        )
345                    };
346                    match self
347                        .tx
348                        .send(Message::Finished {
349                            worker_id: self.id,
350                            track_name,
351                            output_linear,
352                            process_epoch,
353                        })
354                        .await
355                    {
356                        Ok(_) => {}
357                        Err(e) => {
358                            error!("Error while sending Finished: {}", e);
359                        }
360                    }
361                }
362                Message::ProcessOfflineBounce(job) => {
363                    self.process_offline_bounce(job).await;
364                }
365                _ => {}
366            }
367        }
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::Worker;
    use crate::message::{
        Action, Message, OfflineAutomationLane, OfflineAutomationPoint, OfflineAutomationTarget,
        OfflineBounceWork,
    };
    use crate::mutex::UnsafeMutex;
    use crate::state::State;
    use crate::track::Track;
    use std::path::PathBuf;
    use std::sync::{Arc, atomic::AtomicBool};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::sync::mpsc::channel;

    /// Wraps a single track in a fresh engine state.
    fn make_state_with_track(track: Track) -> Arc<UnsafeMutex<State>> {
        let key = track.name.clone();
        let mut state = State::default();
        state
            .tracks
            .insert(key, Arc::new(UnsafeMutex::new(Box::new(track))));
        Arc::new(UnsafeMutex::new(state))
    }

    /// Returns a temp-dir WAV path made unique by the current time in nanoseconds.
    fn unique_temp_wav(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock")
            .as_nanos();
        let file_name = format!("maolan_{name}_{nanos}.wav");
        std::env::temp_dir().join(file_name)
    }

    /// Builds a worker plus the receiving end of its reply channel.
    ///
    /// The worker's inbound channel is never read by these tests, so its sender is dropped.
    fn make_worker(
        id: usize,
        reply_capacity: usize,
    ) -> (Worker, tokio::sync::mpsc::Receiver<Message>) {
        let (_inbound_tx, rx) = channel(1);
        let (tx, replies) = channel(reply_capacity);
        (Worker { id, rx, tx }, replies)
    }

    /// Builds a bounce job starting at sample 0 at 120 BPM in 4/4 with no automation.
    fn make_job(
        state: Arc<UnsafeMutex<State>>,
        track_name: &str,
        output_path: String,
        length_samples: usize,
        canceled: bool,
    ) -> OfflineBounceWork {
        OfflineBounceWork {
            state,
            track_name: track_name.to_string(),
            output_path,
            start_sample: 0,
            length_samples,
            tempo_bpm: 120.0,
            tsig_num: 4,
            tsig_denom: 4,
            automation_lanes: vec![],
            cancel: Arc::new(AtomicBool::new(canceled)),
        }
    }

    /// Shorthand for an automation point.
    fn point(sample: usize, value: f32) -> OfflineAutomationPoint {
        OfflineAutomationPoint { sample, value }
    }

    #[test]
    fn prepare_track_for_freeze_render_neutralizes_level_and_balance() {
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 64, 48_000.0);
        track.set_level(-6.0);
        track.set_balance(0.35);

        let (saved_level, saved_balance) = Worker::prepare_track_for_freeze_render(&mut track);

        assert_eq!(saved_level, -6.0);
        assert_eq!(saved_balance, 0.35);
        assert_eq!(track.level(), 0.0);
        assert_eq!(track.balance, 0.0);

        Worker::restore_track_after_freeze_render(&mut track, saved_level, saved_balance);
        assert_eq!(track.level(), -6.0);
        assert_eq!(track.balance, 0.35);
    }

    #[test]
    fn freeze_automation_ignores_volume_and_balance_lanes() {
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 64, 48_000.0);
        let lanes: Vec<OfflineAutomationLane> = vec![
            (OfflineAutomationTarget::Volume, 0.0),
            (OfflineAutomationTarget::Balance, 1.0),
            (OfflineAutomationTarget::Mute, 1.0),
        ]
        .into_iter()
        .map(|(target, value)| OfflineAutomationLane {
            target,
            points: vec![point(0, value)],
        })
        .collect();

        Worker::apply_freeze_automation_at_sample(&mut track, 0, &lanes);

        assert_eq!(track.level(), 0.0);
        assert_eq!(track.balance, 0.0);
        assert!(track.muted);
    }

    #[test]
    fn automation_lane_value_at_interpolates_between_points() {
        let lane = [point(10, 0.25), point(20, 0.75)];

        let value = Worker::automation_lane_value_at(&lane, 15).expect("value");

        assert!((value - 0.5).abs() < 1.0e-6);
    }

    #[test]
    fn freeze_automation_applies_interpolated_mute_lane() {
        let mut track = Track::new("track".to_string(), 1, 1, 0, 0, 64, 48_000.0);
        let lanes = vec![OfflineAutomationLane {
            target: OfflineAutomationTarget::Mute,
            points: vec![point(0, 0.0), point(10, 1.0)],
        }];

        // Halfway up the ramp the value reaches the mute threshold.
        Worker::apply_freeze_automation_at_sample(&mut track, 5, &lanes);
        assert!(track.muted);

        // Earlier on the ramp the value is below the threshold.
        track.set_muted(false);
        Worker::apply_freeze_automation_at_sample(&mut track, 2, &lanes);
        assert!(!track.muted);
    }

    #[tokio::test]
    async fn process_offline_bounce_errors_when_track_is_missing() {
        let (worker, mut out_rx) = make_worker(7, 8);
        let job = make_job(
            Arc::new(UnsafeMutex::new(State::default())),
            "missing",
            unique_temp_wav("missing").to_string_lossy().to_string(),
            8,
            false,
        );

        worker.process_offline_bounce(job).await;

        match out_rx.recv().await.expect("message") {
            Message::OfflineBounceFinished { result: Err(err) } => {
                assert!(err.contains("Track not found: missing"));
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[tokio::test]
    async fn process_offline_bounce_cancels_and_restores_track_state() {
        let (worker, mut out_rx) = make_worker(5, 8);
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 4, 48_000.0);
        track.set_level(-9.0);
        track.set_balance(-0.3);
        let state = make_state_with_track(track);
        let job = make_job(
            state.clone(),
            "track",
            unique_temp_wav("cancel").to_string_lossy().to_string(),
            8,
            true,
        );

        worker.process_offline_bounce(job).await;

        match out_rx.recv().await.expect("message") {
            Message::OfflineBounceFinished {
                result: Ok(Action::TrackOfflineBounceCanceled { track_name }),
            } => assert_eq!(track_name, "track"),
            other => panic!("unexpected message: {other:?}"),
        }
        assert!(matches!(out_rx.recv().await, Some(Message::Ready(5))));
        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!(track.level(), -9.0);
        assert_eq!(track.balance, -0.3);
    }

    #[tokio::test]
    async fn process_offline_bounce_restores_track_state_on_write_failure() {
        let (worker, mut out_rx) = make_worker(3, 8);
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 4, 48_000.0);
        track.set_level(-4.0);
        track.set_balance(0.25);
        let state = make_state_with_track(track);
        // The temp directory itself is used as the output path so the WAV write fails.
        let output_path = std::env::temp_dir().to_string_lossy().to_string();
        let job = make_job(state.clone(), "track", output_path, 4, false);

        worker.process_offline_bounce(job).await;

        let mut saw_error = false;
        while let Some(message) = out_rx.recv().await {
            match message {
                Message::OfflineBounceFinished {
                    result: Ok(Action::TrackOfflineBounceProgress { .. }),
                } => {}
                Message::OfflineBounceFinished { result: Err(err) } => {
                    assert!(err.contains("Failed to write offline bounce"));
                    saw_error = true;
                }
                Message::Ready(3) => break,
                other => panic!("unexpected message: {other:?}"),
            }
        }
        assert!(saw_error);
        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!(track.level(), -4.0);
        assert_eq!(track.balance, 0.25);
    }

    #[tokio::test]
    async fn process_offline_bounce_emits_progress_and_completion() {
        let (worker, mut out_rx) = make_worker(2, 16);
        let mut track = Track::new("track".to_string(), 1, 1, 0, 0, 4, 48_000.0);
        track.set_level(-3.0);
        track.set_balance(0.4);
        let state = make_state_with_track(track);
        let output = unique_temp_wav("success");
        let job = make_job(
            state.clone(),
            "track",
            output.to_string_lossy().to_string(),
            8,
            false,
        );

        worker.process_offline_bounce(job).await;

        let (mut saw_progress, mut saw_complete, mut saw_ready) = (false, false, false);
        while let Some(message) = out_rx.recv().await {
            match message {
                Message::OfflineBounceFinished {
                    result:
                        Ok(Action::TrackOfflineBounceProgress {
                            track_name,
                            progress,
                            ..
                        }),
                } => {
                    assert_eq!(track_name, "track");
                    assert!(progress > 0.0);
                    saw_progress = true;
                }
                Message::OfflineBounceFinished {
                    result:
                        Ok(Action::TrackOfflineBounce {
                            track_name,
                            output_path,
                            ..
                        }),
                } => {
                    assert_eq!(track_name, "track");
                    assert_eq!(output_path, output.to_string_lossy());
                    saw_complete = true;
                }
                Message::Ready(2) => {
                    saw_ready = true;
                    break;
                }
                other => panic!("unexpected message: {other:?}"),
            }
        }

        assert!(saw_progress);
        assert!(saw_complete);
        assert!(saw_ready);
        assert!(output.exists());
        std::fs::remove_file(&output).expect("remove temp wav");
        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!(track.level(), -3.0);
        assert_eq!(track.balance, 0.4);
        assert!(!track.muted);
    }
}