Skip to main content

maolan_engine/workers/
worker.rs

1use crate::message::{
2    Action, Message, OfflineAutomationLane, OfflineAutomationPoint, OfflineAutomationTarget,
3    OfflineBounceWork,
4};
5#[cfg(unix)]
6use nix::libc;
7use std::time::Instant;
8use tokio::sync::mpsc::{Receiver, Sender};
9use tracing::error;
10use wavers::write as write_wav;
11
12#[derive(Debug)]
13pub struct Worker {
14    id: usize,
15    rx: Receiver<Message>,
16    tx: Sender<Message>,
17}
18
19impl Worker {
20    fn automation_lane_value_at(points: &[OfflineAutomationPoint], sample: usize) -> Option<f32> {
21        if points.is_empty() {
22            return None;
23        }
24        if sample <= points[0].sample {
25            return Some(points[0].value.clamp(0.0, 1.0));
26        }
27        if sample >= points[points.len().saturating_sub(1)].sample {
28            return Some(points[points.len().saturating_sub(1)].value.clamp(0.0, 1.0));
29        }
30        for segment in points.windows(2) {
31            let left = &segment[0];
32            let right = &segment[1];
33            if sample < left.sample || sample > right.sample {
34                continue;
35            }
36            let span = right.sample.saturating_sub(left.sample).max(1) as f32;
37            let t = (sample.saturating_sub(left.sample) as f32 / span).clamp(0.0, 1.0);
38            return Some((left.value + (right.value - left.value) * t).clamp(0.0, 1.0));
39        }
40        None
41    }
42
43    fn apply_freeze_automation_at_sample(
44        track: &mut crate::track::Track,
45        sample: usize,
46        lanes: &[OfflineAutomationLane],
47    ) {
48        for lane in lanes {
49            if matches!(
50                lane.target,
51                OfflineAutomationTarget::Volume | OfflineAutomationTarget::Balance
52            ) {
53                continue;
54            }
55            let Some(value) = Self::automation_lane_value_at(&lane.points, sample) else {
56                continue;
57            };
58            match lane.target {
59                OfflineAutomationTarget::Volume | OfflineAutomationTarget::Balance => {}
60                OfflineAutomationTarget::Mute => {
61                    track.set_muted(value >= 0.5);
62                }
63                #[cfg(all(unix, not(target_os = "macos")))]
64                OfflineAutomationTarget::Lv2Parameter {
65                    instance_id,
66                    index,
67                    min,
68                    max,
69                } => {
70                    let lo = min.min(max);
71                    let hi = max.max(min);
72                    let param_value = (lo + value * (hi - lo)).clamp(lo, hi);
73                    let _ = track.set_lv2_control_value(instance_id, index, param_value);
74                }
75                OfflineAutomationTarget::Vst3Parameter {
76                    instance_id,
77                    param_id,
78                } => {
79                    let _ = track.set_vst3_parameter(instance_id, param_id, value.clamp(0.0, 1.0));
80                }
81                OfflineAutomationTarget::ClapParameter {
82                    instance_id,
83                    param_id,
84                    min,
85                    max,
86                } => {
87                    let lo = min.min(max);
88                    let hi = max.max(min);
89                    let param_value = (lo + value as f64 * (hi - lo)).clamp(lo, hi);
90                    let _ = track.set_clap_parameter_at(instance_id, param_id, param_value, 0);
91                }
92            }
93        }
94    }
95
96    fn prepare_track_for_freeze_render(track: &mut crate::track::Track) -> (f32, f32) {
97        let original_level = track.level();
98        let original_balance = track.balance;
99        track.set_level(0.0);
100        track.set_balance(0.0);
101        (original_level, original_balance)
102    }
103
104    fn restore_track_after_freeze_render(
105        track: &mut crate::track::Track,
106        original_level: f32,
107        original_balance: f32,
108    ) {
109        track.set_level(original_level);
110        track.set_balance(original_balance);
111    }
112
113    async fn process_offline_bounce(&self, job: OfflineBounceWork) {
114        let track_handle = job.state.lock().tracks.get(&job.track_name).cloned();
115        let Some(target_track) = track_handle else {
116            let _ = self
117                .tx
118                .send(Message::OfflineBounceFinished {
119                    result: Err(format!("Track not found: {}", job.track_name)),
120                })
121                .await;
122            return;
123        };
124        let (channels, block_size, sample_rate) = {
125            let t = target_track.lock();
126            let block_size = t
127                .audio
128                .outs
129                .first()
130                .map(|io| io.buffer.lock().len())
131                .or_else(|| t.audio.ins.first().map(|io| io.buffer.lock().len()))
132                .unwrap_or(0)
133                .max(1);
134            (
135                t.audio.outs.len().max(1),
136                block_size,
137                t.sample_rate.round().max(1.0) as i32,
138            )
139        };
140        let (original_level, original_balance) = {
141            let t = target_track.lock();
142            Self::prepare_track_for_freeze_render(t)
143        };
144
145        let mut rendered = vec![0.0_f32; job.length_samples.saturating_mul(channels)];
146        let mut cursor = 0usize;
147        while cursor < job.length_samples {
148            if job.cancel.load(std::sync::atomic::Ordering::Relaxed) {
149                {
150                    let t = target_track.lock();
151                    Self::restore_track_after_freeze_render(t, original_level, original_balance);
152                }
153                let _ = self
154                    .tx
155                    .send(Message::OfflineBounceFinished {
156                        result: Ok(Action::TrackOfflineBounceCanceled {
157                            track_name: job.track_name.clone(),
158                        }),
159                    })
160                    .await;
161                let _ = self.tx.send(Message::Ready(self.id)).await;
162                return;
163            }
164
165            let step = (job.length_samples - cursor).min(block_size);
166            let tracks: Vec<_> = job.state.lock().tracks.values().cloned().collect();
167            for handle in &tracks {
168                let t = handle.lock();
169                t.audio.finished = false;
170                t.audio.processing = false;
171                t.set_transport_sample(job.start_sample.saturating_add(cursor));
172                t.set_loop_config(false, None);
173                t.set_transport_timing(job.tempo_bpm, job.tsig_num, job.tsig_denom);
174                t.set_clip_playback_enabled(true);
175                t.set_record_tap_enabled(false);
176            }
177
178            let mut remaining = tracks.len();
179            while remaining > 0 {
180                let mut progressed = false;
181                for handle in &tracks {
182                    let t = handle.lock();
183                    if t.audio.finished || t.audio.processing {
184                        continue;
185                    }
186                    if t.audio.ready() {
187                        if t.name == job.track_name {
188                            Self::apply_freeze_automation_at_sample(
189                                t,
190                                job.start_sample.saturating_add(cursor),
191                                &job.automation_lanes,
192                            );
193                        }
194                        t.audio.processing = true;
195                        t.process();
196                        t.audio.processing = false;
197                        progressed = true;
198                        remaining = remaining.saturating_sub(1);
199                    }
200                }
201                if !progressed {
202                    for handle in &tracks {
203                        let t = handle.lock();
204                        if t.audio.finished {
205                            continue;
206                        }
207                        if t.name == job.track_name {
208                            Self::apply_freeze_automation_at_sample(
209                                t,
210                                job.start_sample.saturating_add(cursor),
211                                &job.automation_lanes,
212                            );
213                        }
214                        t.audio.processing = true;
215                        t.process();
216                        t.audio.processing = false;
217                        remaining = remaining.saturating_sub(1);
218                    }
219                }
220            }
221
222            {
223                let t = target_track.lock();
224                for ch in 0..channels {
225                    let out = t.audio.outs[ch].buffer.lock();
226                    let copy_len = step.min(out.len());
227                    for i in 0..copy_len {
228                        let dst = (cursor + i) * channels + ch;
229                        rendered[dst] = out[i];
230                    }
231                }
232            }
233
234            cursor = cursor.saturating_add(step);
235            let _ = self
236                .tx
237                .send(Message::OfflineBounceFinished {
238                    result: Ok(Action::TrackOfflineBounceProgress {
239                        track_name: job.track_name.clone(),
240                        progress: (cursor as f32 / job.length_samples as f32).clamp(0.0, 1.0),
241                        operation: Some("Rendering freeze".to_string()),
242                    }),
243                })
244                .await;
245        }
246
247        if let Err(e) =
248            write_wav::<f32, _>(&job.output_path, &rendered, sample_rate, channels as u16)
249        {
250            {
251                let t = target_track.lock();
252                Self::restore_track_after_freeze_render(t, original_level, original_balance);
253            }
254            let _ = self
255                .tx
256                .send(Message::OfflineBounceFinished {
257                    result: Err(format!(
258                        "Failed to write offline bounce '{}': {e}",
259                        job.output_path
260                    )),
261                })
262                .await;
263            let _ = self.tx.send(Message::Ready(self.id)).await;
264            return;
265        }
266
267        {
268            let t = target_track.lock();
269            Self::restore_track_after_freeze_render(t, original_level, original_balance);
270        }
271
272        let _ = self
273            .tx
274            .send(Message::OfflineBounceFinished {
275                result: Ok(Action::TrackOfflineBounce {
276                    track_name: job.track_name,
277                    output_path: job.output_path,
278                    start_sample: job.start_sample,
279                    length_samples: job.length_samples,
280                    automation_lanes: vec![],
281                }),
282            })
283            .await;
284        let _ = self.tx.send(Message::Ready(self.id)).await;
285    }
286
287    #[cfg(unix)]
288    fn try_enable_realtime() -> Result<(), String> {
289        let thread = unsafe { libc::pthread_self() };
290        let policy = libc::SCHED_FIFO;
291        let param = unsafe {
292            let mut p = std::mem::zeroed::<libc::sched_param>();
293            p.sched_priority = 10;
294            p
295        };
296        let rc = unsafe { libc::pthread_setschedparam(thread, policy, &param) };
297        if rc == 0 {
298            Ok(())
299        } else {
300            Err(format!("pthread_setschedparam failed with errno {}", rc))
301        }
302    }
303
304    #[cfg(not(unix))]
305    fn try_enable_realtime() -> Result<(), String> {
306        Err("Realtime thread priority is not supported on this platform".to_string())
307    }
308
309    pub async fn new(id: usize, rx: Receiver<Message>, tx: Sender<Message>) -> Worker {
310        let worker = Worker { id, rx, tx };
311        worker.send(Message::Ready(id)).await;
312        worker
313    }
314
315    pub async fn send(&self, message: Message) {
316        self.tx
317            .send(message)
318            .await
319            .expect("Failed to send message from worker");
320    }
321
322    pub async fn work(&mut self) {
323        if let Err(e) = Self::try_enable_realtime() {
324            error!("Worker {} realtime priority not enabled: {}", self.id, e);
325        }
326        while let Some(message) = self.rx.recv().await {
327            match message {
328                Message::Request(Action::Quit) => {
329                    return;
330                }
331                Message::ProcessTrack(t) => {
332                    let (track_name, output_linear, process_epoch) = {
333                        let track = t.lock();
334                        let process_epoch = track.process_epoch;
335                        let started = Instant::now();
336                        track.process();
337                        let elapsed = started.elapsed();
338                        if elapsed.as_millis() > 20 {
339                            tracing::warn!(
340                                "Slow track process '{}' took {:.3} ms",
341                                track.name,
342                                elapsed.as_secs_f64() * 1000.0
343                            );
344                        }
345                        track.audio.processing = false;
346                        (
347                            track.name.clone(),
348                            track.output_meter_linear(),
349                            process_epoch,
350                        )
351                    };
352                    match self
353                        .tx
354                        .send(Message::Finished {
355                            worker_id: self.id,
356                            track_name,
357                            output_linear,
358                            process_epoch,
359                        })
360                        .await
361                    {
362                        Ok(_) => {}
363                        Err(e) => {
364                            error!("Error while sending Finished: {}", e);
365                        }
366                    }
367                }
368                Message::ProcessOfflineBounce(job) => {
369                    self.process_offline_bounce(job).await;
370                }
371                _ => {}
372            }
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::Worker;
    use crate::message::{
        Action, Message, OfflineAutomationLane, OfflineAutomationPoint, OfflineAutomationTarget,
        OfflineBounceWork,
    };
    use crate::mutex::UnsafeMutex;
    use crate::state::State;
    use crate::track::Track;
    use std::path::PathBuf;
    use std::sync::{Arc, atomic::AtomicBool};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::sync::mpsc::channel;

    /// Wraps `track` in a fresh session state, keyed by the track's name.
    fn make_state_with_track(track: Track) -> Arc<UnsafeMutex<State>> {
        let mut state = State::default();
        let key = track.name.clone();
        state
            .tracks
            .insert(key, Arc::new(UnsafeMutex::new(Box::new(track))));
        Arc::new(UnsafeMutex::new(state))
    }

    /// Returns a WAV path in the temp dir made unique by a nanosecond timestamp.
    fn unique_temp_wav(name: &str) -> PathBuf {
        let stamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock")
            .as_nanos();
        let mut path = std::env::temp_dir();
        path.push(format!("maolan_{name}_{stamp}.wav"));
        path
    }

    /// Builds a worker plus the receiver that observes everything it sends.
    /// The worker's own inbox is never fed in these tests.
    fn worker_with_outbox(
        id: usize,
        capacity: usize,
    ) -> (Worker, tokio::sync::mpsc::Receiver<Message>) {
        let (_inbox_tx, inbox_rx) = channel(1);
        let (outbox_tx, outbox_rx) = channel(capacity);
        let worker = Worker {
            id,
            rx: inbox_rx,
            tx: outbox_tx,
        };
        (worker, outbox_rx)
    }

    /// Builds a bounce job starting at sample 0 in 4/4 at 120 BPM, without automation.
    fn bounce_job(
        state: Arc<UnsafeMutex<State>>,
        track_name: &str,
        output_path: String,
        length_samples: usize,
        cancel: bool,
    ) -> OfflineBounceWork {
        OfflineBounceWork {
            state,
            track_name: track_name.to_string(),
            output_path,
            start_sample: 0,
            length_samples,
            tempo_bpm: 120.0,
            tsig_num: 4,
            tsig_denom: 4,
            automation_lanes: vec![],
            cancel: Arc::new(AtomicBool::new(cancel)),
        }
    }

    /// Builds an automation lane from `(sample, value)` pairs.
    fn lane(target: OfflineAutomationTarget, points: &[(usize, f32)]) -> OfflineAutomationLane {
        OfflineAutomationLane {
            target,
            points: points
                .iter()
                .map(|&(sample, value)| OfflineAutomationPoint { sample, value })
                .collect(),
        }
    }

    #[test]
    fn prepare_track_for_freeze_render_neutralizes_level_and_balance() {
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 64, 48_000.0);
        track.set_level(-6.0);
        track.set_balance(0.35);

        let saved = Worker::prepare_track_for_freeze_render(&mut track);
        assert_eq!(saved, (-6.0, 0.35));
        assert_eq!((track.level(), track.balance), (0.0, 0.0));

        Worker::restore_track_after_freeze_render(&mut track, saved.0, saved.1);
        assert_eq!((track.level(), track.balance), (-6.0, 0.35));
    }

    #[test]
    fn freeze_automation_ignores_volume_and_balance_lanes() {
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 64, 48_000.0);
        let lanes = [
            lane(OfflineAutomationTarget::Volume, &[(0, 0.0)]),
            lane(OfflineAutomationTarget::Balance, &[(0, 1.0)]),
            lane(OfflineAutomationTarget::Mute, &[(0, 1.0)]),
        ];

        Worker::apply_freeze_automation_at_sample(&mut track, 0, &lanes);

        assert_eq!((track.level(), track.balance), (0.0, 0.0));
        assert!(track.muted);
    }

    #[test]
    fn automation_lane_value_at_interpolates_between_points() {
        let points = [
            OfflineAutomationPoint {
                sample: 10,
                value: 0.25,
            },
            OfflineAutomationPoint {
                sample: 20,
                value: 0.75,
            },
        ];

        let midpoint = Worker::automation_lane_value_at(&points, 15).expect("value");

        assert!((midpoint - 0.5).abs() < 1.0e-6);
    }

    #[test]
    fn freeze_automation_applies_interpolated_mute_lane() {
        let mut track = Track::new("track".to_string(), 1, 1, 0, 0, 64, 48_000.0);
        let lanes = [lane(OfflineAutomationTarget::Mute, &[(0, 0.0), (10, 1.0)])];

        // Halfway up the ramp the lane reaches 0.5, which counts as muted.
        Worker::apply_freeze_automation_at_sample(&mut track, 5, &lanes);
        assert!(track.muted);

        track.set_muted(false);
        Worker::apply_freeze_automation_at_sample(&mut track, 2, &lanes);
        assert!(!track.muted);
    }

    #[tokio::test]
    async fn process_offline_bounce_errors_when_track_is_missing() {
        let (worker, mut outbox) = worker_with_outbox(7, 8);
        let job = bounce_job(
            Arc::new(UnsafeMutex::new(State::default())),
            "missing",
            unique_temp_wav("missing").to_string_lossy().to_string(),
            8,
            false,
        );

        worker.process_offline_bounce(job).await;

        match outbox.recv().await.expect("message") {
            Message::OfflineBounceFinished { result: Err(err) } => {
                assert!(err.contains("Track not found: missing"))
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[tokio::test]
    async fn process_offline_bounce_cancels_and_restores_track_state() {
        let (worker, mut outbox) = worker_with_outbox(5, 8);
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 4, 48_000.0);
        track.set_level(-9.0);
        track.set_balance(-0.3);
        let state = make_state_with_track(track);
        let job = bounce_job(
            state.clone(),
            "track",
            unique_temp_wav("cancel").to_string_lossy().to_string(),
            8,
            true,
        );

        worker.process_offline_bounce(job).await;

        let first = outbox.recv().await.expect("message");
        assert!(
            matches!(
                &first,
                Message::OfflineBounceFinished {
                    result: Ok(Action::TrackOfflineBounceCanceled { track_name }),
                } if track_name == "track"
            ),
            "unexpected message: {first:?}"
        );
        assert!(matches!(outbox.recv().await, Some(Message::Ready(5))));

        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!((track.level(), track.balance), (-9.0, -0.3));
    }

    #[tokio::test]
    async fn process_offline_bounce_restores_track_state_on_write_failure() {
        let (worker, mut outbox) = worker_with_outbox(3, 8);
        let mut track = Track::new("track".to_string(), 1, 2, 0, 0, 4, 48_000.0);
        track.set_level(-4.0);
        track.set_balance(0.25);
        let state = make_state_with_track(track);
        // A directory is not a writable file path, so the WAV write must fail.
        let unwritable = std::env::temp_dir().to_string_lossy().to_string();
        let job = bounce_job(state.clone(), "track", unwritable, 4, false);

        worker.process_offline_bounce(job).await;

        let mut saw_error = false;
        while let Some(message) = outbox.recv().await {
            match message {
                Message::Ready(3) => break,
                Message::OfflineBounceFinished { result: Err(err) } => {
                    assert!(err.contains("Failed to write offline bounce"));
                    saw_error = true;
                }
                Message::OfflineBounceFinished {
                    result: Ok(Action::TrackOfflineBounceProgress { .. }),
                } => {}
                other => panic!("unexpected message: {other:?}"),
            }
        }
        assert!(saw_error);

        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!((track.level(), track.balance), (-4.0, 0.25));
    }

    #[tokio::test]
    async fn process_offline_bounce_emits_progress_and_completion() {
        let (worker, mut outbox) = worker_with_outbox(2, 16);
        let mut track = Track::new("track".to_string(), 1, 1, 0, 0, 4, 48_000.0);
        track.set_level(-3.0);
        track.set_balance(0.4);
        let state = make_state_with_track(track);
        let output = unique_temp_wav("success");
        let job = bounce_job(
            state.clone(),
            "track",
            output.to_string_lossy().to_string(),
            8,
            false,
        );

        worker.process_offline_bounce(job).await;

        let (mut saw_progress, mut saw_complete, mut saw_ready) = (false, false, false);
        while let Some(message) = outbox.recv().await {
            match message {
                Message::Ready(2) => {
                    saw_ready = true;
                    break;
                }
                Message::OfflineBounceFinished {
                    result:
                        Ok(Action::TrackOfflineBounceProgress {
                            track_name,
                            progress,
                            ..
                        }),
                } => {
                    assert_eq!(track_name, "track");
                    assert!(progress > 0.0);
                    saw_progress = true;
                }
                Message::OfflineBounceFinished {
                    result:
                        Ok(Action::TrackOfflineBounce {
                            track_name,
                            output_path,
                            ..
                        }),
                } => {
                    assert_eq!(track_name, "track");
                    assert_eq!(output_path, output.to_string_lossy());
                    saw_complete = true;
                }
                other => panic!("unexpected message: {other:?}"),
            }
        }

        assert!(saw_progress && saw_complete && saw_ready);
        assert!(output.exists());
        std::fs::remove_file(&output).expect("remove temp wav");

        let track = state.lock().tracks.get("track").expect("track").lock();
        assert_eq!((track.level(), track.balance), (-3.0, 0.4));
        assert!(!track.muted);
    }
}