xs/generators/
generator.rs

1use scru128::Scru128Id;
2use tokio::task::JoinHandle;
3
4use nu_protocol::{ByteStream, ByteStreamType, PipelineData, Signals, Span, Value};
5use std::io::Read;
6use std::sync::atomic::AtomicBool;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::io::AsyncReadExt;
10
11use crate::nu;
12use crate::nu::ReturnOptions;
13use crate::store::{FollowOption, Frame, ReadOptions, Store};
14use serde_json::json;
15
16#[derive(Clone, Debug, serde::Deserialize, Default)]
17pub struct GeneratorScriptOptions {
18    pub duplex: Option<bool>,
19    pub return_options: Option<ReturnOptions>,
20}
21
/// Context shared by every run of one generator.
#[derive(Clone)]
pub struct GeneratorLoop {
    /// Base topic: the spawn frame's topic with any `.spawn` suffix removed.
    /// Event frames are written to `<topic>.<event>`.
    pub topic: String,
}
26
27#[derive(Clone)]
28pub struct Task {
29    pub id: Scru128Id,
30    pub run_closure: nu_protocol::engine::Closure,
31    pub return_options: Option<ReturnOptions>,
32    pub duplex: bool,
33    pub engine: nu::Engine,
34}
35
36#[cfg_attr(not(test), allow(dead_code))]
37#[derive(Debug, Clone)]
38pub enum GeneratorEventKind {
39    Running,
40    /// output frame flushed; payload is raw bytes
41    Recv {
42        suffix: String,
43        data: Vec<u8>,
44    },
45    Stopped(StopReason),
46    ParseError {
47        message: String,
48    },
49    Shutdown,
50}
51
52#[cfg_attr(not(test), allow(dead_code))]
53#[derive(Debug, Clone)]
54pub struct GeneratorEvent {
55    pub kind: GeneratorEventKind,
56}
57
58#[cfg_attr(not(test), allow(dead_code))]
59#[derive(Debug, Clone)]
60pub enum StopReason {
61    Finished,
62    Error { message: String },
63    Terminate,
64    Update { update_id: Scru128Id },
65}
66
67pub(crate) fn emit_event(
68    store: &Store,
69    loop_ctx: &GeneratorLoop,
70    source_id: Scru128Id,
71    return_opts: Option<&ReturnOptions>,
72    kind: GeneratorEventKind,
73) -> Result<GeneratorEvent, Box<dyn std::error::Error + Send + Sync>> {
74    match &kind {
75        GeneratorEventKind::Running => {
76            store.append(
77                Frame::builder(format!("{topic}.running", topic = loop_ctx.topic))
78                    .meta(json!({ "source_id": source_id.to_string() }))
79                    .build(),
80            )?;
81        }
82
83        GeneratorEventKind::Recv { suffix, data } => {
84            let hash = store.cas_insert_bytes_sync(data)?;
85            store.append(
86                Frame::builder(format!(
87                    "{topic}.{suffix}",
88                    topic = loop_ctx.topic,
89                    suffix = suffix
90                ))
91                .hash(hash)
92                .maybe_ttl(return_opts.and_then(|o| o.ttl.clone()))
93                .meta(json!({ "source_id": source_id.to_string() }))
94                .build(),
95            )?;
96        }
97
98        GeneratorEventKind::Stopped(reason) => {
99            let mut meta = json!({
100                "source_id": source_id.to_string(),
101                "reason": stop_reason_str(reason),
102            });
103            if let StopReason::Update { update_id } = reason {
104                meta["update_id"] = json!(update_id.to_string());
105            }
106            if let StopReason::Error { message } = reason {
107                meta["message"] = json!(message);
108            }
109            store.append(
110                Frame::builder(format!("{topic}.stopped", topic = loop_ctx.topic))
111                    .meta(meta)
112                    .build(),
113            )?;
114        }
115
116        GeneratorEventKind::ParseError { message } => {
117            store.append(
118                Frame::builder(format!("{topic}.parse.error", topic = loop_ctx.topic))
119                    .meta(json!({
120                        "source_id": source_id.to_string(),
121                        "reason": message,
122                    }))
123                    .build(),
124            )?;
125        }
126
127        GeneratorEventKind::Shutdown => {
128            store.append(
129                Frame::builder(format!("{topic}.shutdown", topic = loop_ctx.topic))
130                    .meta(json!({ "source_id": source_id.to_string() }))
131                    .build(),
132            )?;
133        }
134    }
135
136    Ok(GeneratorEvent { kind })
137}
138
139fn stop_reason_str(r: &StopReason) -> &'static str {
140    match r {
141        StopReason::Finished => "finished",
142        StopReason::Error { .. } => "error",
143        StopReason::Terminate => "terminate",
144        StopReason::Update { .. } => "update",
145    }
146}
147
148pub fn spawn(store: Store, engine: nu::Engine, spawn_frame: Frame) -> JoinHandle<()> {
149    tokio::spawn(async move { run(store, engine, spawn_frame).await })
150}
151
152async fn run(store: Store, mut engine: nu::Engine, spawn_frame: Frame) {
153    let pristine = engine.clone();
154    let hash = match spawn_frame.hash.clone() {
155        Some(h) => h,
156        None => return,
157    };
158    let mut reader = match store.cas_reader(hash).await {
159        Ok(r) => r,
160        Err(_) => return,
161    };
162    let mut script = String::new();
163    if reader.read_to_string(&mut script).await.is_err() {
164        return;
165    }
166
167    let loop_ctx = GeneratorLoop {
168        topic: spawn_frame
169            .topic
170            .strip_suffix(".spawn")
171            .unwrap_or(&spawn_frame.topic)
172            .to_string(),
173    };
174
175    let nu_config = match nu::parse_config(&mut engine, &script) {
176        Ok(cfg) => cfg,
177        Err(e) => {
178            let _ = emit_event(
179                &store,
180                &loop_ctx,
181                spawn_frame.id,
182                None,
183                GeneratorEventKind::ParseError {
184                    message: e.to_string(),
185                },
186            );
187            return;
188        }
189    };
190    let opts: GeneratorScriptOptions = nu_config.deserialize_options().unwrap_or_default();
191
192    // Create and set the interrupt signal on the engine state
193    let interrupt = Arc::new(AtomicBool::new(false));
194    engine.state.set_signals(Signals::new(interrupt.clone()));
195
196    let task = Task {
197        id: spawn_frame.id,
198        run_closure: nu_config.run_closure,
199        return_options: opts.return_options,
200        duplex: opts.duplex.unwrap_or(false),
201        engine,
202    };
203
204    run_loop(store, loop_ctx, task, pristine).await;
205}
206
207async fn run_loop(store: Store, loop_ctx: GeneratorLoop, mut task: Task, pristine: nu::Engine) {
208    // Create the first start frame and set up a persistent control subscription
209    let _ = emit_event(
210        &store,
211        &loop_ctx,
212        task.id,
213        task.return_options.as_ref(),
214        GeneratorEventKind::Running,
215    );
216    let start_frame = store
217        .last(&format!("{topic}.running", topic = loop_ctx.topic))
218        .expect("running frame");
219    let mut start_id = start_frame.id;
220
221    let control_rx_options = ReadOptions::builder()
222        .follow(FollowOption::On)
223        .after(start_id)
224        .build();
225
226    let mut control_rx = store.read(control_rx_options).await;
227
228    enum LoopOutcome {
229        Continue,
230        Update(Box<Task>, Scru128Id),
231        Terminate,
232        Error(String),
233    }
234
235    impl core::fmt::Debug for LoopOutcome {
236        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
237            match self {
238                LoopOutcome::Continue => write!(f, "Continue"),
239                LoopOutcome::Update(_, id) => f.debug_tuple("Update").field(id).finish(),
240                LoopOutcome::Terminate => write!(f, "Terminate"),
241                LoopOutcome::Error(e) => f.debug_tuple("Error").field(e).finish(),
242            }
243        }
244    }
245
246    impl From<&LoopOutcome> for StopReason {
247        fn from(value: &LoopOutcome) -> Self {
248            match value {
249                LoopOutcome::Continue => StopReason::Finished,
250                LoopOutcome::Update(_, id) => StopReason::Update { update_id: *id },
251                LoopOutcome::Terminate => StopReason::Terminate,
252                LoopOutcome::Error(e) => StopReason::Error { message: e.clone() },
253            }
254        }
255    }
256
257    loop {
258        let input_pipeline = if task.duplex {
259            let options = ReadOptions::builder()
260                .follow(FollowOption::On)
261                .after(start_id)
262                .build();
263            let send_rx = store.read(options).await;
264            build_input_pipeline(store.clone(), &loop_ctx, &task, send_rx).await
265        } else {
266            PipelineData::empty()
267        };
268
269        let (done_tx, done_rx) = tokio::sync::oneshot::channel();
270        spawn_thread(
271            store.clone(),
272            loop_ctx.clone(),
273            task.clone(),
274            input_pipeline,
275            done_tx,
276        );
277
278        let terminate_topic = format!("{topic}.terminate", topic = loop_ctx.topic);
279        let spawn_topic = format!("{topic}.spawn", topic = loop_ctx.topic);
280        tokio::pin!(done_rx);
281
282        let outcome = 'ctrl: loop {
283            tokio::select! {
284                biased;
285                maybe = control_rx.recv() => {
286                    match maybe {
287                        Some(frame) if frame.topic == terminate_topic => {
288                            task.engine.state.signals().trigger();
289                            task.engine.kill_job_by_name(&task.id.to_string());
290                            let _ = (&mut done_rx).await;
291                            break 'ctrl LoopOutcome::Terminate;
292                        }
293                        Some(frame) if frame.topic == spawn_topic => {
294                            if let Some(hash) = frame.hash.clone() {
295                                if let Ok(mut reader) = store.cas_reader(hash).await {
296                                    let mut script = String::new();
297                                    if reader.read_to_string(&mut script).await.is_ok() {
298                                        let mut new_engine = pristine.clone();
299                                        match nu::parse_config(&mut new_engine, &script) {
300                                            Ok(cfg) => {
301                                                let opts: GeneratorScriptOptions = cfg.deserialize_options().unwrap_or_default();
302                                                let interrupt = Arc::new(AtomicBool::new(false));
303                                                new_engine.state.set_signals(Signals::new(interrupt.clone()));
304
305                                                task.engine.state.signals().trigger();
306                                                task.engine.kill_job_by_name(&task.id.to_string());
307                                                let _ = (&mut done_rx).await;
308
309                                                let new_task = Task {
310                                                    id: frame.id,
311                                                    run_closure: cfg.run_closure,
312                                                    return_options: opts.return_options,
313                                                    duplex: opts.duplex.unwrap_or(false),
314                                                    engine: new_engine,
315                                                };
316
317                                                break 'ctrl LoopOutcome::Update(Box::new(new_task), frame.id);
318                                            }
319                                            Err(e) => {
320                                                let _ = emit_event(
321                                                    &store,
322                                                    &loop_ctx,
323                                                    frame.id,
324                                                    None,
325                                                    GeneratorEventKind::ParseError { message: e.to_string() },
326                                                );
327                                            }
328                                        }
329                                    }
330                                }
331                            }
332                        }
333                        Some(_) => {}
334                        None => break 'ctrl LoopOutcome::Error("control".into()),
335                    }
336                }
337                res = &mut done_rx => {
338                    break 'ctrl match res.unwrap_or(Err("thread failed".into())) {
339                        Ok(()) => LoopOutcome::Continue,
340                        Err(e) => LoopOutcome::Error(e),
341                    };
342                }
343            }
344        };
345
346        let reason: StopReason = (&outcome).into();
347        let _ = emit_event(
348            &store,
349            &loop_ctx,
350            task.id,
351            task.return_options.as_ref(),
352            GeneratorEventKind::Stopped(reason.clone()),
353        );
354
355        match outcome {
356            LoopOutcome::Continue => {
357                tokio::time::sleep(Duration::from_secs(1)).await;
358                let _ = emit_event(
359                    &store,
360                    &loop_ctx,
361                    task.id,
362                    task.return_options.as_ref(),
363                    GeneratorEventKind::Running,
364                );
365            }
366            LoopOutcome::Update(new_task, _) => {
367                task = *new_task;
368                let _ = emit_event(
369                    &store,
370                    &loop_ctx,
371                    task.id,
372                    task.return_options.as_ref(),
373                    GeneratorEventKind::Running,
374                );
375            }
376            LoopOutcome::Terminate | LoopOutcome::Error(_) => {
377                let _ = emit_event(
378                    &store,
379                    &loop_ctx,
380                    task.id,
381                    task.return_options.as_ref(),
382                    GeneratorEventKind::Shutdown,
383                );
384                break;
385            }
386        }
387
388        if let Some(f) = store.last(&format!("{topic}.running", topic = loop_ctx.topic)) {
389            start_id = f.id;
390        }
391    }
392}
393
394async fn build_input_pipeline(
395    store: Store,
396    loop_ctx: &GeneratorLoop,
397    task: &Task,
398    rx: tokio::sync::mpsc::Receiver<Frame>,
399) -> PipelineData {
400    let topic = format!("{loop_topic}.send", loop_topic = loop_ctx.topic);
401    let signals = task.engine.state.signals().clone();
402    let mut rx = rx;
403    let iter = std::iter::from_fn(move || loop {
404        if signals.interrupted() {
405            return None;
406        }
407
408        match rx.try_recv() {
409            Ok(frame) => {
410                if frame.topic == topic {
411                    if let Some(hash) = frame.hash {
412                        if let Ok(bytes) = store.cas_read_sync(&hash) {
413                            if let Ok(content) = String::from_utf8(bytes) {
414                                return Some(content);
415                            }
416                        }
417                    }
418                }
419            }
420            Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
421                std::thread::sleep(std::time::Duration::from_millis(10));
422                continue;
423            }
424            Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
425                return None;
426            }
427        }
428    });
429
430    ByteStream::from_iter(
431        iter,
432        Span::unknown(),
433        task.engine.state.signals().clone(),
434        ByteStreamType::Unknown,
435    )
436    .into()
437}
438
439fn spawn_thread(
440    store: Store,
441    loop_ctx: GeneratorLoop,
442    mut task: Task,
443    input_pipeline: PipelineData,
444    done_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
445) {
446    let handle = tokio::runtime::Handle::current();
447    std::thread::spawn(move || {
448        let res = match task.engine.run_closure_in_job(
449            &task.run_closure,
450            None,
451            Some(input_pipeline),
452            task.id.to_string(),
453        ) {
454            Ok(pipeline) => {
455                match pipeline {
456                    PipelineData::Empty => {}
457                    PipelineData::Value(value, _) => {
458                        if let Value::String { val, .. } = value {
459                            let suffix = task
460                                .return_options
461                                .as_ref()
462                                .and_then(|o| o.suffix.clone())
463                                .unwrap_or_else(|| "recv".into());
464                            handle.block_on(async {
465                                let _ = emit_event(
466                                    &store,
467                                    &loop_ctx,
468                                    task.id,
469                                    task.return_options.as_ref(),
470                                    GeneratorEventKind::Recv {
471                                        suffix: suffix.clone(),
472                                        data: val.into_bytes(),
473                                    },
474                                );
475                            });
476                        }
477                    }
478                    PipelineData::ListStream(mut stream, _) => {
479                        while let Some(value) = stream.next_value() {
480                            if let Value::String { val, .. } = value {
481                                let suffix = task
482                                    .return_options
483                                    .as_ref()
484                                    .and_then(|o| o.suffix.clone())
485                                    .unwrap_or_else(|| "recv".into());
486                                handle.block_on(async {
487                                    let _ = emit_event(
488                                        &store,
489                                        &loop_ctx,
490                                        task.id,
491                                        task.return_options.as_ref(),
492                                        GeneratorEventKind::Recv {
493                                            suffix: suffix.clone(),
494                                            data: val.into_bytes(),
495                                        },
496                                    );
497                                });
498                            }
499                        }
500                    }
501                    PipelineData::ByteStream(stream, _) => {
502                        if let Some(mut reader) = stream.reader() {
503                            let suffix = task
504                                .return_options
505                                .as_ref()
506                                .and_then(|o| o.suffix.clone())
507                                .unwrap_or_else(|| "recv".into());
508                            let mut buf = [0u8; 8192];
509                            loop {
510                                match reader.read(&mut buf) {
511                                    Ok(0) => break,
512                                    Ok(n) => {
513                                        let chunk = &buf[..n];
514                                        handle.block_on(async {
515                                            let _ = emit_event(
516                                                &store,
517                                                &loop_ctx,
518                                                task.id,
519                                                task.return_options.as_ref(),
520                                                GeneratorEventKind::Recv {
521                                                    suffix: suffix.clone(),
522                                                    data: chunk.to_vec(),
523                                                },
524                                            );
525                                        });
526                                    }
527                                    Err(_) => break,
528                                }
529                            }
530                        }
531                    }
532                }
533                Ok(())
534            }
535            Err(e) => {
536                let working_set = nu_protocol::engine::StateWorkingSet::new(&task.engine.state);
537                Err(nu_protocol::format_cli_error(None, &working_set, &*e, None))
538            }
539        };
540
541        let _ = done_tx.send(res);
542    });
543}