1use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::mpsc;
13
14use anyhow::Result;
15use async_trait::async_trait;
16use folk_core::config::WorkersConfig;
17use folk_core::runtime::{Runtime, WorkerHandle};
18use tracing::debug;
19
20use crate::bridge;
21use crate::worker;
22
/// Source of process-wide worker ids, starting at 1 and shared by
/// pre-connected and ZTS workers alike (unique until the `u32` wraps).
///
/// `Relaxed` ordering is sufficient: `fetch_add` is atomic, so every caller
/// receives a distinct value, and no other memory is published through it.
static NEXT_WORKER_ID: AtomicU32 = AtomicU32::new(1);
24
25pub struct WorkerTxSide {
27 pub task_tx: mpsc::SyncSender<bridge::TaskRequest>,
28 pub ready_rx: mpsc::Receiver<()>,
29}
30
31pub struct ExtensionRuntime {
33 config: WorkersConfig,
34 channels: std::sync::Mutex<Vec<WorkerTxSide>>,
36}
37
38impl ExtensionRuntime {
39 pub fn new(config: WorkersConfig, tx_sides: Vec<WorkerTxSide>) -> Self {
41 Self {
42 config,
43 channels: std::sync::Mutex::new(tx_sides),
44 }
45 }
46
47 #[allow(clippy::unnecessary_wraps)] fn spawn_zts_worker(&self) -> Result<Box<dyn WorkerHandle>> {
50 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
51 let script = &self.config.script;
52
53 let (task_tx, task_rx) = mpsc::sync_channel::<bridge::TaskRequest>(8);
54 let (ready_tx, ready_rx) = mpsc::sync_channel::<()>(1);
55
56 let handle = worker::spawn_zts_worker(worker_id, script.to_string(), task_rx, ready_tx);
57 crate::register_zts_worker(handle);
58
59 debug!(worker_id, "ZTS worker thread spawned");
60
61 Ok(Box::new(ChannelWorkerHandle {
62 worker_id,
63 task_tx: Some(task_tx),
64 ready_rx: Some(ready_rx),
65 }))
66 }
67
68 fn take_preconnected(&self) -> Result<Box<dyn WorkerHandle>> {
70 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
71 let tx_side = self.channels.lock().unwrap().pop().ok_or_else(|| {
72 anyhow::anyhow!("no more pre-connected channels (worker {worker_id})")
73 })?;
74
75 debug!(worker_id, "pre-connected worker channel taken");
76
77 Ok(Box::new(ChannelWorkerHandle {
78 worker_id,
79 task_tx: Some(tx_side.task_tx),
80 ready_rx: Some(tx_side.ready_rx),
81 }))
82 }
83}
84
85#[async_trait]
86impl Runtime for ExtensionRuntime {
87 async fn spawn(&self) -> Result<Box<dyn WorkerHandle>> {
88 let has_preconnected = !self.channels.lock().unwrap().is_empty();
89
90 if has_preconnected {
91 self.take_preconnected()
92 } else if self.config.count > 1 {
93 self.spawn_zts_worker()
94 } else {
95 anyhow::bail!("no workers available and ZTS multi-worker not requested")
96 }
97 }
98}
99
100pub struct ChannelWorkerHandle {
102 worker_id: u32,
103 task_tx: Option<mpsc::SyncSender<bridge::TaskRequest>>,
104 ready_rx: Option<mpsc::Receiver<()>>,
105}
106
107#[async_trait]
108impl WorkerHandle for ChannelWorkerHandle {
109 fn id(&self) -> u32 {
110 self.worker_id
111 }
112
113 async fn ready(&mut self) -> Result<()> {
114 if let Some(rx) = self.ready_rx.take() {
115 tokio::task::spawn_blocking(move || rx.recv())
116 .await
117 .map_err(|e| anyhow::anyhow!("spawn_blocking panicked: {e}"))?
118 .map_err(|_| anyhow::anyhow!("worker died before ready"))?;
119 }
120 Ok(())
121 }
122
123 async fn execute(
124 &mut self,
125 method: &str,
126 payload: serde_json::Value,
127 ) -> Result<serde_json::Value> {
128 let tx = self
129 .task_tx
130 .as_ref()
131 .ok_or_else(|| anyhow::anyhow!("worker terminated"))?
132 .clone();
133
134 let method = method.to_string();
135
136 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
138
139 tx.send(bridge::TaskRequest {
142 method,
143 payload,
144 reply: reply_tx,
145 })
146 .map_err(|_| anyhow::anyhow!("worker process gone"))?;
147
148 reply_rx
150 .await
151 .map_err(|_| anyhow::anyhow!("worker dropped reply"))?
152 }
153
154 async fn terminate(&mut self) -> Result<()> {
155 self.task_tx.take();
156 Ok(())
157 }
158}