xs/generators/
generator.rs

1use scru128::Scru128Id;
2use tokio::task::JoinHandle;
3
4use nu_protocol::{ByteStream, ByteStreamType, PipelineData, Signals, Span, Value};
5use std::io::Read;
6use std::sync::atomic::AtomicBool;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::io::AsyncReadExt;
10
11use crate::nu;
12use crate::nu::ReturnOptions;
13use crate::store::{FollowOption, Frame, ReadOptions, Store};
14use serde_json::json;
15
/// Optional settings a generator script can declare next to its run closure.
///
/// `run` and `run_loop` obtain this via `deserialize_options()` on the parsed
/// script config and fall back to `Default` (all `None`) when that fails.
#[derive(Clone, Debug, serde::Deserialize, Default)]
pub struct GeneratorScriptOptions {
    /// Whether the closure should be fed an input pipeline of `.send` frames.
    /// A missing value is treated as `false` by the code that builds a `Task`.
    pub duplex: Option<bool>,
    /// Options applied to emitted output frames (their `suffix` and `ttl` are
    /// read elsewhere in this file).
    pub return_options: Option<ReturnOptions>,
}
21
/// Identifies the topic namespace and context a generator publishes into.
#[derive(Clone)]
pub struct GeneratorLoop {
    /// Base topic; lifecycle and output frames are appended as `<topic>.<suffix>`.
    /// `run` derives it from the spawn frame's topic with a trailing `.spawn` removed.
    pub topic: String,
    /// Context id attached to every frame emitted for this generator and used
    /// to scope its store subscriptions.
    pub context_id: Scru128Id,
}
27
/// One runnable version of a generator: the parsed closure plus the engine and
/// options it runs with.
#[derive(Clone)]
pub struct Task {
    /// Id of the `.spawn` frame this task was built from. It is recorded as
    /// `source_id` in emitted frame metadata and, stringified, used as the job
    /// name when running and killing the closure.
    pub id: Scru128Id,
    /// Closure taken from the parsed script config; executed by `spawn_thread`.
    pub run_closure: nu_protocol::engine::Closure,
    /// Output frame options (`suffix`, `ttl`) declared by the script, if any.
    pub return_options: Option<ReturnOptions>,
    /// When `true`, `run_loop` builds an input pipeline from `.send` frames.
    pub duplex: bool,
    /// Engine the closure runs in; its signals are triggered to interrupt a run.
    pub engine: nu::Engine,
}
36
/// The kinds of lifecycle/output events a generator records.
///
/// `emit_event` turns each variant into a frame whose topic is the generator's
/// base topic plus the suffix noted on the variant.
#[cfg_attr(not(test), allow(dead_code))]
#[derive(Debug, Clone)]
pub enum GeneratorEventKind {
    /// The closure is (re)starting; written as `<topic>.running`.
    Running,
    /// Output frame flushed; payload is raw bytes.
    /// Written as `<topic>.<suffix>` with `data` stored in the CAS.
    Recv {
        /// Topic suffix for the output frame.
        suffix: String,
        /// Raw payload bytes.
        data: Vec<u8>,
    },
    /// A run of the closure ended; written as `<topic>.stopped`.
    Stopped(StopReason),
    /// The script failed to parse; written as `<topic>.parse.error`.
    ParseError {
        /// Error text, stored under the `reason` metadata key.
        message: String,
    },
    /// The generator loop is exiting for good; written as `<topic>.shutdown`.
    Shutdown,
}
52
/// Value returned by `emit_event` after the corresponding frame was appended.
#[cfg_attr(not(test), allow(dead_code))]
#[derive(Debug, Clone)]
pub struct GeneratorEvent {
    /// The event that was recorded.
    pub kind: GeneratorEventKind,
}
58
/// Why a single run of the generator closure ended.
///
/// Serialized into the `reason` metadata field of `.stopped` frames via
/// `stop_reason_str`.
#[cfg_attr(not(test), allow(dead_code))]
#[derive(Debug, Clone)]
pub enum StopReason {
    /// The worker thread reported success (`"finished"`).
    Finished,
    /// The run failed; `message` is also written to the frame's `message`
    /// metadata field (`"error"`).
    Error { message: String },
    /// A `.terminate` control frame was received (`"terminate"`).
    Terminate,
    /// A new `.spawn` frame replaced the running task; `update_id` is that
    /// frame's id and is written to the `update_id` metadata field (`"update"`).
    Update { update_id: Scru128Id },
}
67
68pub(crate) fn emit_event(
69    store: &Store,
70    loop_ctx: &GeneratorLoop,
71    source_id: Scru128Id,
72    return_opts: Option<&ReturnOptions>,
73    kind: GeneratorEventKind,
74) -> Result<GeneratorEvent, Box<dyn std::error::Error + Send + Sync>> {
75    match &kind {
76        GeneratorEventKind::Running => {
77            store.append(
78                Frame::builder(
79                    format!("{topic}.running", topic = loop_ctx.topic),
80                    loop_ctx.context_id,
81                )
82                .meta(json!({ "source_id": source_id.to_string() }))
83                .build(),
84            )?;
85        }
86
87        GeneratorEventKind::Recv { suffix, data } => {
88            let hash = store.cas_insert_bytes_sync(data)?;
89            store.append(
90                Frame::builder(
91                    format!("{topic}.{suffix}", topic = loop_ctx.topic, suffix = suffix),
92                    loop_ctx.context_id,
93                )
94                .hash(hash)
95                .maybe_ttl(return_opts.and_then(|o| o.ttl.clone()))
96                .meta(json!({ "source_id": source_id.to_string() }))
97                .build(),
98            )?;
99        }
100
101        GeneratorEventKind::Stopped(reason) => {
102            let mut meta = json!({
103                "source_id": source_id.to_string(),
104                "reason": stop_reason_str(reason),
105            });
106            if let StopReason::Update { update_id } = reason {
107                meta["update_id"] = json!(update_id.to_string());
108            }
109            if let StopReason::Error { message } = reason {
110                meta["message"] = json!(message);
111            }
112            store.append(
113                Frame::builder(
114                    format!("{topic}.stopped", topic = loop_ctx.topic),
115                    loop_ctx.context_id,
116                )
117                .meta(meta)
118                .build(),
119            )?;
120        }
121
122        GeneratorEventKind::ParseError { message } => {
123            store.append(
124                Frame::builder(
125                    format!("{topic}.parse.error", topic = loop_ctx.topic),
126                    loop_ctx.context_id,
127                )
128                .meta(json!({
129                    "source_id": source_id.to_string(),
130                    "reason": message,
131                }))
132                .build(),
133            )?;
134        }
135
136        GeneratorEventKind::Shutdown => {
137            store.append(
138                Frame::builder(
139                    format!("{topic}.shutdown", topic = loop_ctx.topic),
140                    loop_ctx.context_id,
141                )
142                .meta(json!({ "source_id": source_id.to_string() }))
143                .build(),
144            )?;
145        }
146    }
147
148    Ok(GeneratorEvent { kind })
149}
150
151fn stop_reason_str(r: &StopReason) -> &'static str {
152    match r {
153        StopReason::Finished => "finished",
154        StopReason::Error { .. } => "error",
155        StopReason::Terminate => "terminate",
156        StopReason::Update { .. } => "update",
157    }
158}
159
160pub fn spawn(store: Store, engine: nu::Engine, spawn_frame: Frame) -> JoinHandle<()> {
161    tokio::spawn(async move { run(store, engine, spawn_frame).await })
162}
163
164async fn run(store: Store, mut engine: nu::Engine, spawn_frame: Frame) {
165    let pristine = engine.clone();
166    let hash = match spawn_frame.hash.clone() {
167        Some(h) => h,
168        None => return,
169    };
170    let mut reader = match store.cas_reader(hash).await {
171        Ok(r) => r,
172        Err(_) => return,
173    };
174    let mut script = String::new();
175    if reader.read_to_string(&mut script).await.is_err() {
176        return;
177    }
178
179    let loop_ctx = GeneratorLoop {
180        topic: spawn_frame
181            .topic
182            .strip_suffix(".spawn")
183            .unwrap_or(&spawn_frame.topic)
184            .to_string(),
185        context_id: spawn_frame.context_id,
186    };
187
188    let nu_config = match nu::parse_config(&mut engine, &script) {
189        Ok(cfg) => cfg,
190        Err(e) => {
191            let _ = emit_event(
192                &store,
193                &loop_ctx,
194                spawn_frame.id,
195                None,
196                GeneratorEventKind::ParseError {
197                    message: e.to_string(),
198                },
199            );
200            return;
201        }
202    };
203    let opts: GeneratorScriptOptions = nu_config.deserialize_options().unwrap_or_default();
204
205    // Create and set the interrupt signal on the engine state
206    let interrupt = Arc::new(AtomicBool::new(false));
207    engine.state.set_signals(Signals::new(interrupt.clone()));
208
209    let task = Task {
210        id: spawn_frame.id,
211        run_closure: nu_config.run_closure,
212        return_options: opts.return_options,
213        duplex: opts.duplex.unwrap_or(false),
214        engine,
215    };
216
217    run_loop(store, loop_ctx, task, pristine).await;
218}
219
220async fn run_loop(store: Store, loop_ctx: GeneratorLoop, mut task: Task, pristine: nu::Engine) {
221    // Create the first start frame and set up a persistent control subscription
222    let _ = emit_event(
223        &store,
224        &loop_ctx,
225        task.id,
226        task.return_options.as_ref(),
227        GeneratorEventKind::Running,
228    );
229    let start_frame = store
230        .head(
231            &format!("{topic}.running", topic = loop_ctx.topic),
232            loop_ctx.context_id,
233        )
234        .expect("running frame");
235    let mut start_id = start_frame.id;
236
237    let control_rx_options = ReadOptions::builder()
238        .follow(FollowOption::On)
239        .last_id(start_id)
240        .context_id(loop_ctx.context_id)
241        .build();
242
243    let mut control_rx = store.read(control_rx_options).await;
244
245    enum LoopOutcome {
246        Continue,
247        Update(Box<Task>, Scru128Id),
248        Terminate,
249        Error(String),
250    }
251
252    impl core::fmt::Debug for LoopOutcome {
253        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
254            match self {
255                LoopOutcome::Continue => write!(f, "Continue"),
256                LoopOutcome::Update(_, id) => f.debug_tuple("Update").field(id).finish(),
257                LoopOutcome::Terminate => write!(f, "Terminate"),
258                LoopOutcome::Error(e) => f.debug_tuple("Error").field(e).finish(),
259            }
260        }
261    }
262
263    impl From<&LoopOutcome> for StopReason {
264        fn from(value: &LoopOutcome) -> Self {
265            match value {
266                LoopOutcome::Continue => StopReason::Finished,
267                LoopOutcome::Update(_, id) => StopReason::Update { update_id: *id },
268                LoopOutcome::Terminate => StopReason::Terminate,
269                LoopOutcome::Error(e) => StopReason::Error { message: e.clone() },
270            }
271        }
272    }
273
274    loop {
275        let input_pipeline = if task.duplex {
276            let options = ReadOptions::builder()
277                .follow(FollowOption::On)
278                .last_id(start_id)
279                .context_id(loop_ctx.context_id)
280                .build();
281            let send_rx = store.read(options).await;
282            build_input_pipeline(store.clone(), &loop_ctx, &task, send_rx).await
283        } else {
284            PipelineData::empty()
285        };
286
287        let (done_tx, done_rx) = tokio::sync::oneshot::channel();
288        spawn_thread(
289            store.clone(),
290            loop_ctx.clone(),
291            task.clone(),
292            input_pipeline,
293            done_tx,
294        );
295
296        let terminate_topic = format!("{topic}.terminate", topic = loop_ctx.topic);
297        let spawn_topic = format!("{topic}.spawn", topic = loop_ctx.topic);
298        tokio::pin!(done_rx);
299
300        let outcome = 'ctrl: loop {
301            tokio::select! {
302                biased;
303                maybe = control_rx.recv() => {
304                    match maybe {
305                        Some(frame) if frame.topic == terminate_topic => {
306                            task.engine.state.signals().trigger();
307                            task.engine.kill_job_by_name(&task.id.to_string());
308                            let _ = (&mut done_rx).await;
309                            break 'ctrl LoopOutcome::Terminate;
310                        }
311                        Some(frame) if frame.topic == spawn_topic => {
312                            if let Some(hash) = frame.hash.clone() {
313                                if let Ok(mut reader) = store.cas_reader(hash).await {
314                                    let mut script = String::new();
315                                    if reader.read_to_string(&mut script).await.is_ok() {
316                                        let mut new_engine = pristine.clone();
317                                        match nu::parse_config(&mut new_engine, &script) {
318                                            Ok(cfg) => {
319                                                let opts: GeneratorScriptOptions = cfg.deserialize_options().unwrap_or_default();
320                                                let interrupt = Arc::new(AtomicBool::new(false));
321                                                new_engine.state.set_signals(Signals::new(interrupt.clone()));
322
323                                                task.engine.state.signals().trigger();
324                                                task.engine.kill_job_by_name(&task.id.to_string());
325                                                let _ = (&mut done_rx).await;
326
327                                                let new_task = Task {
328                                                    id: frame.id,
329                                                    run_closure: cfg.run_closure,
330                                                    return_options: opts.return_options,
331                                                    duplex: opts.duplex.unwrap_or(false),
332                                                    engine: new_engine,
333                                                };
334
335                                                break 'ctrl LoopOutcome::Update(Box::new(new_task), frame.id);
336                                            }
337                                            Err(e) => {
338                                                let _ = emit_event(
339                                                    &store,
340                                                    &loop_ctx,
341                                                    frame.id,
342                                                    None,
343                                                    GeneratorEventKind::ParseError { message: e.to_string() },
344                                                );
345                                            }
346                                        }
347                                    }
348                                }
349                            }
350                        }
351                        Some(_) => {}
352                        None => break 'ctrl LoopOutcome::Error("control".into()),
353                    }
354                }
355                res = &mut done_rx => {
356                    break 'ctrl match res.unwrap_or(Err("thread failed".into())) {
357                        Ok(()) => LoopOutcome::Continue,
358                        Err(e) => LoopOutcome::Error(e),
359                    };
360                }
361            }
362        };
363
364        let reason: StopReason = (&outcome).into();
365        let _ = emit_event(
366            &store,
367            &loop_ctx,
368            task.id,
369            task.return_options.as_ref(),
370            GeneratorEventKind::Stopped(reason.clone()),
371        );
372
373        match outcome {
374            LoopOutcome::Continue => {
375                tokio::time::sleep(Duration::from_secs(1)).await;
376                let _ = emit_event(
377                    &store,
378                    &loop_ctx,
379                    task.id,
380                    task.return_options.as_ref(),
381                    GeneratorEventKind::Running,
382                );
383            }
384            LoopOutcome::Update(new_task, _) => {
385                task = *new_task;
386                let _ = emit_event(
387                    &store,
388                    &loop_ctx,
389                    task.id,
390                    task.return_options.as_ref(),
391                    GeneratorEventKind::Running,
392                );
393            }
394            LoopOutcome::Terminate | LoopOutcome::Error(_) => {
395                let _ = emit_event(
396                    &store,
397                    &loop_ctx,
398                    task.id,
399                    task.return_options.as_ref(),
400                    GeneratorEventKind::Shutdown,
401                );
402                break;
403            }
404        }
405
406        if let Some(f) = store.head(
407            &format!("{topic}.running", topic = loop_ctx.topic),
408            loop_ctx.context_id,
409        ) {
410            start_id = f.id;
411        }
412    }
413}
414
415async fn build_input_pipeline(
416    store: Store,
417    loop_ctx: &GeneratorLoop,
418    task: &Task,
419    rx: tokio::sync::mpsc::Receiver<Frame>,
420) -> PipelineData {
421    let topic = format!("{loop_topic}.send", loop_topic = loop_ctx.topic);
422    let signals = task.engine.state.signals().clone();
423    let mut rx = rx;
424    let iter = std::iter::from_fn(move || loop {
425        if signals.interrupted() {
426            return None;
427        }
428
429        match rx.try_recv() {
430            Ok(frame) => {
431                if frame.topic == topic {
432                    if let Some(hash) = frame.hash {
433                        if let Ok(bytes) = store.cas_read_sync(&hash) {
434                            if let Ok(content) = String::from_utf8(bytes) {
435                                return Some(content);
436                            }
437                        }
438                    }
439                }
440            }
441            Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
442                std::thread::sleep(std::time::Duration::from_millis(10));
443                continue;
444            }
445            Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
446                return None;
447            }
448        }
449    });
450
451    ByteStream::from_iter(
452        iter,
453        Span::unknown(),
454        task.engine.state.signals().clone(),
455        ByteStreamType::Unknown,
456    )
457    .into()
458}
459
460fn spawn_thread(
461    store: Store,
462    loop_ctx: GeneratorLoop,
463    mut task: Task,
464    input_pipeline: PipelineData,
465    done_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
466) {
467    let handle = tokio::runtime::Handle::current();
468    std::thread::spawn(move || {
469        let res = match task.engine.run_closure_in_job(
470            &task.run_closure,
471            None,
472            Some(input_pipeline),
473            task.id.to_string(),
474        ) {
475            Ok(pipeline) => {
476                match pipeline {
477                    PipelineData::Empty => {}
478                    PipelineData::Value(value, _) => {
479                        if let Value::String { val, .. } = value {
480                            let suffix = task
481                                .return_options
482                                .as_ref()
483                                .and_then(|o| o.suffix.clone())
484                                .unwrap_or_else(|| "recv".into());
485                            handle.block_on(async {
486                                let _ = emit_event(
487                                    &store,
488                                    &loop_ctx,
489                                    task.id,
490                                    task.return_options.as_ref(),
491                                    GeneratorEventKind::Recv {
492                                        suffix: suffix.clone(),
493                                        data: val.into_bytes(),
494                                    },
495                                );
496                            });
497                        }
498                    }
499                    PipelineData::ListStream(mut stream, _) => {
500                        while let Some(value) = stream.next_value() {
501                            if let Value::String { val, .. } = value {
502                                let suffix = task
503                                    .return_options
504                                    .as_ref()
505                                    .and_then(|o| o.suffix.clone())
506                                    .unwrap_or_else(|| "recv".into());
507                                handle.block_on(async {
508                                    let _ = emit_event(
509                                        &store,
510                                        &loop_ctx,
511                                        task.id,
512                                        task.return_options.as_ref(),
513                                        GeneratorEventKind::Recv {
514                                            suffix: suffix.clone(),
515                                            data: val.into_bytes(),
516                                        },
517                                    );
518                                });
519                            }
520                        }
521                    }
522                    PipelineData::ByteStream(stream, _) => {
523                        if let Some(mut reader) = stream.reader() {
524                            let suffix = task
525                                .return_options
526                                .as_ref()
527                                .and_then(|o| o.suffix.clone())
528                                .unwrap_or_else(|| "recv".into());
529                            let mut buf = [0u8; 8192];
530                            loop {
531                                match reader.read(&mut buf) {
532                                    Ok(0) => break,
533                                    Ok(n) => {
534                                        let chunk = &buf[..n];
535                                        handle.block_on(async {
536                                            let _ = emit_event(
537                                                &store,
538                                                &loop_ctx,
539                                                task.id,
540                                                task.return_options.as_ref(),
541                                                GeneratorEventKind::Recv {
542                                                    suffix: suffix.clone(),
543                                                    data: chunk.to_vec(),
544                                                },
545                                            );
546                                        });
547                                    }
548                                    Err(_) => break,
549                                }
550                            }
551                        }
552                    }
553                }
554                Ok(())
555            }
556            Err(e) => {
557                let working_set = nu_protocol::engine::StateWorkingSet::new(&task.engine.state);
558                Err(nu_protocol::format_cli_error(&working_set, &*e, None))
559            }
560        };
561
562        let _ = done_tx.send(res);
563    });
564}