Skip to main content

rune_framework/
caster.rs

1//! Caster — connects to Rune runtime and executes registered handlers.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use bytes::Bytes;
8use tokio::sync::mpsc;
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tonic::transport::Channel;
12
13use crate::proto::rune_service_client::RuneServiceClient;
14use crate::proto::{
15    session_message::Payload, CasterAttach, ErrorDetail, ExecuteResult,
16    GateConfig as ProtoGateConfig, Heartbeat, RuneDeclaration, SessionMessage, StreamEnd,
17    StreamEvent,
18};
19
20use crate::config::{CasterConfig, FileAttachment, RuneConfig};
21use crate::context::RuneContext;
22use crate::error::{SdkError, SdkResult};
23use crate::handler::{BoxFuture, HandlerKind, RegisteredRune};
24use crate::stream::StreamSender;
25
26/// Caster connects to a Rune Runtime and registers Rune handlers.
27pub struct Caster {
28    config: CasterConfig,
29    caster_id: String,
30    runes: Arc<RwLock<HashMap<String, RegisteredRune>>>,
31    shutdown_token: CancellationToken,
32}
33
34impl Caster {
35    /// Create a new Caster with the given configuration.
36    pub fn new(config: CasterConfig) -> Self {
37        let caster_id = config
38            .caster_id
39            .clone()
40            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
41        Self {
42            config,
43            caster_id,
44            runes: Arc::new(RwLock::new(HashMap::new())),
45            shutdown_token: CancellationToken::new(),
46        }
47    }
48
49    /// Returns the caster ID.
50    pub fn caster_id(&self) -> &str {
51        &self.caster_id
52    }
53
54    /// Returns a reference to the configuration.
55    pub fn config(&self) -> &CasterConfig {
56        &self.config
57    }
58
59    /// Returns the number of registered runes.
60    pub fn rune_count(&self) -> usize {
61        self.runes.read().unwrap().len()
62    }
63
64    /// Returns the config of a registered rune by name.
65    pub fn get_rune_config(&self, name: &str) -> Option<RuneConfig> {
66        self.runes.read().unwrap().get(name).map(|r| r.config.clone())
67    }
68
69    /// Check if a rune is registered as a stream handler.
70    pub fn is_stream_rune(&self, name: &str) -> bool {
71        self.runes
72            .read()
73            .unwrap()
74            .get(name)
75            .map(|r| r.handler.is_stream())
76            .unwrap_or(false)
77    }
78
79    /// Signal the Caster to stop its run loop.
80    ///
81    /// Safe to call from any thread or task. The [`run()`](Self::run) method
82    /// will return shortly after this is called. Idempotent — calling
83    /// multiple times is safe.
84    pub fn stop(&self) {
85        self.shutdown_token.cancel();
86    }
87
88    /// Check if a rune handler accepts file attachments.
89    pub fn rune_accepts_files(&self, name: &str) -> bool {
90        self.runes
91            .read()
92            .unwrap()
93            .get(name)
94            .map(|r| r.handler.accepts_files())
95            .unwrap_or(false)
96    }
97
98    // -----------------------------------------------------------------------
99    // Registration
100    // -----------------------------------------------------------------------
101
102    /// Register a unary rune handler.
103    ///
104    /// The handler receives `(RuneContext, Bytes)` and returns `Result<Bytes>`.
105    ///
106    /// # Errors
107    /// Returns `SdkError::DuplicateRune` if a rune with the same name already exists.
108    pub fn rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
109    where
110        F: Fn(RuneContext, Bytes) -> Fut + Send + Sync + 'static,
111        Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
112    {
113        let handler = Arc::new(move |ctx, input| -> BoxFuture<'static, SdkResult<Bytes>> {
114            Box::pin(handler(ctx, input))
115        });
116        self.register_inner(config, HandlerKind::Unary(handler))
117    }
118
119    /// Register a unary rune handler that accepts file attachments.
120    pub fn rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
121    where
122        F: Fn(RuneContext, Bytes, Vec<FileAttachment>) -> Fut + Send + Sync + 'static,
123        Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
124    {
125        let handler =
126            Arc::new(
127                move |ctx, input, files| -> BoxFuture<'static, SdkResult<Bytes>> {
128                    Box::pin(handler(ctx, input, files))
129                },
130            );
131        self.register_inner(config, HandlerKind::UnaryWithFiles(handler))
132    }
133
134    /// Register a streaming rune handler.
135    ///
136    /// The handler receives `(RuneContext, Bytes, StreamSender)` and returns `Result<()>`.
137    pub fn stream_rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
138    where
139        F: Fn(RuneContext, Bytes, StreamSender) -> Fut + Send + Sync + 'static,
140        Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
141    {
142        let handler =
143            Arc::new(move |ctx, input, stream| -> BoxFuture<'static, SdkResult<()>> {
144                Box::pin(handler(ctx, input, stream))
145            });
146        let mut cfg = config;
147        cfg.supports_stream = true;
148        self.register_inner(cfg, HandlerKind::Stream(handler))
149    }
150
151    /// Register a streaming rune handler that accepts file attachments.
152    pub fn stream_rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
153    where
154        F: Fn(RuneContext, Bytes, Vec<FileAttachment>, StreamSender) -> Fut
155            + Send
156            + Sync
157            + 'static,
158        Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
159    {
160        let handler =
161            Arc::new(
162                move |ctx, input, files, stream| -> BoxFuture<'static, SdkResult<()>> {
163                    Box::pin(handler(ctx, input, files, stream))
164                },
165            );
166        let mut cfg = config;
167        cfg.supports_stream = true;
168        self.register_inner(cfg, HandlerKind::StreamWithFiles(handler))
169    }
170
171    fn register_inner(&self, config: RuneConfig, handler: HandlerKind) -> SdkResult<()> {
172        let name = config.name.clone();
173        let registered = RegisteredRune { config, handler };
174        let mut runes = self.runes.write().unwrap();
175        if runes.contains_key(&name) {
176            return Err(SdkError::DuplicateRune(name));
177        }
178        runes.insert(name, registered);
179        Ok(())
180    }
181
182    // -----------------------------------------------------------------------
183    // Run
184    // -----------------------------------------------------------------------
185
186    /// Start the Caster (blocking async). Connects to Runtime with auto-reconnect.
187    ///
188    /// Returns when the session ends normally, or when [`stop()`](Self::stop)
189    /// is called, or on unrecoverable error.
190    pub async fn run(&self) -> SdkResult<()> {
191        let mut delay = Duration::from_secs_f64(self.config.reconnect_base_delay_secs);
192        let max_delay = Duration::from_secs_f64(self.config.reconnect_max_delay_secs);
193
194        loop {
195            if self.shutdown_token.is_cancelled() {
196                return Ok(());
197            }
198            match self.connect_and_run().await {
199                Ok(()) => return Ok(()),
200                Err(e) => {
201                    if self.shutdown_token.is_cancelled() {
202                        return Ok(());
203                    }
204                    tracing::warn!(
205                        "connection error: {}, reconnecting in {:?}",
206                        e,
207                        delay
208                    );
209                    tokio::select! {
210                        _ = tokio::time::sleep(delay) => {}
211                        _ = self.shutdown_token.cancelled() => {
212                            return Ok(());
213                        }
214                    }
215                    delay = (delay * 2).min(max_delay);
216                }
217            }
218        }
219    }
220
221    async fn connect_and_run(&self) -> SdkResult<()> {
222        let endpoint = format!("http://{}", self.config.runtime);
223        let channel = Channel::from_shared(endpoint)
224            .map_err(|e| SdkError::InvalidUri(e.to_string()))?
225            .connect()
226            .await?;
227        let mut client = RuneServiceClient::new(channel);
228
229        // Outbound channel
230        let (tx, rx) = mpsc::channel::<SessionMessage>(32);
231        let outbound = ReceiverStream::new(rx);
232        let response = client.session(outbound).await?;
233        let mut inbound = response.into_inner();
234
235        // Send CasterAttach
236        let attach_msg = self.build_attach_message();
237        tx.send(attach_msg)
238            .await
239            .map_err(|e| SdkError::ChannelSend(e.to_string()))?;
240
241        // Start heartbeat
242        let hb_tx = tx.clone();
243        let hb_interval = Duration::from_secs_f64(self.config.heartbeat_interval_secs);
244        let hb_handle = tokio::spawn(async move {
245            loop {
246                tokio::time::sleep(hb_interval).await;
247                let msg = SessionMessage {
248                    payload: Some(Payload::Heartbeat(Heartbeat {
249                        timestamp_ms: std::time::SystemTime::now()
250                            .duration_since(std::time::UNIX_EPOCH)
251                            .unwrap_or_default()
252                            .as_millis() as u64,
253                    })),
254                };
255                if hb_tx.send(msg).await.is_err() {
256                    break;
257                }
258            }
259        });
260
261        // Cancellation tokens per request (use tokio::sync for async spawned tasks)
262        let cancel_tokens: Arc<tokio::sync::RwLock<HashMap<String, CancellationToken>>> =
263            Arc::new(tokio::sync::RwLock::new(HashMap::new()));
264
265        // Message loop
266        loop {
267            let msg = tokio::select! {
268                msg = inbound.message() => {
269                    match msg? {
270                        Some(m) => m,
271                        None => break, // stream ended
272                    }
273                }
274                _ = self.shutdown_token.cancelled() => {
275                    break;
276                }
277            };
278            match msg.payload {
279                Some(Payload::AttachAck(ack)) => {
280                    if ack.accepted {
281                        tracing::info!(
282                            "attached to {}, caster_id={}",
283                            self.config.runtime,
284                            self.caster_id
285                        );
286                    } else {
287                        tracing::error!("attach rejected: {}", ack.reason);
288                        return Err(SdkError::Other(format!(
289                            "attach rejected: {}",
290                            ack.reason
291                        )));
292                    }
293                }
294                Some(Payload::Execute(req)) => {
295                    let registered = self.runes.read().unwrap().get(&req.rune_name).cloned();
296
297                    let token = CancellationToken::new();
298                    cancel_tokens
299                        .write()
300                        .await
301                        .insert(req.request_id.clone(), token.clone());
302
303                    let tx_clone = tx.clone();
304                    let cancel_tokens_clone = cancel_tokens.clone();
305                    let request_id = req.request_id.clone();
306                    tokio::spawn(async move {
307                        execute_handler(registered, req, tx_clone, token).await;
308                        cancel_tokens_clone.write().await.remove(&request_id);
309                    });
310                }
311                Some(Payload::Cancel(cancel)) => {
312                    if let Some(token) =
313                        cancel_tokens.read().await.get(&cancel.request_id)
314                    {
315                        token.cancel();
316                    }
317                    tracing::info!("cancel requested: {}", cancel.request_id);
318                }
319                Some(Payload::Heartbeat(_)) => {
320                    // Server heartbeat — acknowledged silently
321                }
322                _ => {}
323            }
324        }
325
326        hb_handle.abort();
327        Ok(())
328    }
329
330    fn build_attach_message(&self) -> SessionMessage {
331        let runes = self.runes.read().unwrap();
332        let mut declarations = Vec::new();
333
334        for registered in runes.values() {
335            let cfg = &registered.config;
336            let gate = cfg.gate.as_ref().map(|g| ProtoGateConfig {
337                path: g.path.clone(),
338                method: g.method.clone(),
339            });
340            let input_schema = cfg
341                .input_schema
342                .as_ref()
343                .map(|s| serde_json::to_string(s).unwrap_or_default())
344                .unwrap_or_default();
345            let output_schema = cfg
346                .output_schema
347                .as_ref()
348                .map(|s| serde_json::to_string(s).unwrap_or_default())
349                .unwrap_or_default();
350
351            declarations.push(RuneDeclaration {
352                name: cfg.name.clone(),
353                version: cfg.version.clone(),
354                description: cfg.description.clone(),
355                input_schema,
356                output_schema,
357                supports_stream: cfg.supports_stream,
358                gate,
359                priority: cfg.priority,
360            });
361        }
362
363        SessionMessage {
364            payload: Some(Payload::Attach(CasterAttach {
365                caster_id: self.caster_id.clone(),
366                runes: declarations,
367                labels: self.config.labels.clone(),
368                max_concurrent: self.config.max_concurrent,
369                key: self.config.key.clone().unwrap_or_default(),
370            })),
371        }
372    }
373}
374
375// ---------------------------------------------------------------------------
376// Execute dispatch (free function)
377// ---------------------------------------------------------------------------
378
379async fn execute_handler(
380    registered: Option<RegisteredRune>,
381    req: crate::proto::ExecuteRequest,
382    tx: mpsc::Sender<SessionMessage>,
383    cancel_token: CancellationToken,
384) {
385    let request_id = req.request_id.clone();
386
387    let Some(registered) = registered else {
388        let _ = tx
389            .send(SessionMessage {
390                payload: Some(Payload::Result(ExecuteResult {
391                    request_id,
392                    status: crate::proto::Status::Failed.into(),
393                    output: vec![],
394                    error: Some(ErrorDetail {
395                        code: "NOT_FOUND".into(),
396                        message: format!("rune '{}' not found", req.rune_name),
397                        details: vec![],
398                    }),
399                    attachments: vec![],
400                })),
401            })
402            .await;
403        return;
404    };
405
406    let ctx = RuneContext {
407        rune_name: req.rune_name.clone(),
408        request_id: request_id.clone(),
409        context: req.context.clone(),
410        cancellation: cancel_token,
411    };
412
413    let input = Bytes::from(req.input);
414    let files: Vec<FileAttachment> = req
415        .attachments
416        .iter()
417        .map(|a| FileAttachment {
418            filename: a.filename.clone(),
419            data: Bytes::from(a.data.clone()),
420            mime_type: a.mime_type.clone(),
421        })
422        .collect();
423
424    match &registered.handler {
425        HandlerKind::Unary(handler) => {
426            let result = handler(ctx, input).await;
427            let msg = match result {
428                Ok(output) => SessionMessage {
429                    payload: Some(Payload::Result(ExecuteResult {
430                        request_id,
431                        status: crate::proto::Status::Completed.into(),
432                        output: output.to_vec(),
433                        error: None,
434                        attachments: vec![],
435                    })),
436                },
437                Err(e) => SessionMessage {
438                    payload: Some(Payload::Result(ExecuteResult {
439                        request_id,
440                        status: crate::proto::Status::Failed.into(),
441                        output: vec![],
442                        error: Some(ErrorDetail {
443                            code: "EXECUTION_FAILED".into(),
444                            message: e.to_string(),
445                            details: vec![],
446                        }),
447                        attachments: vec![],
448                    })),
449                },
450            };
451            let _ = tx.send(msg).await;
452        }
453        HandlerKind::UnaryWithFiles(handler) => {
454            let result = handler(ctx, input, files).await;
455            let msg = match result {
456                Ok(output) => SessionMessage {
457                    payload: Some(Payload::Result(ExecuteResult {
458                        request_id,
459                        status: crate::proto::Status::Completed.into(),
460                        output: output.to_vec(),
461                        error: None,
462                        attachments: vec![],
463                    })),
464                },
465                Err(e) => SessionMessage {
466                    payload: Some(Payload::Result(ExecuteResult {
467                        request_id,
468                        status: crate::proto::Status::Failed.into(),
469                        output: vec![],
470                        error: Some(ErrorDetail {
471                            code: "EXECUTION_FAILED".into(),
472                            message: e.to_string(),
473                            details: vec![],
474                        }),
475                        attachments: vec![],
476                    })),
477                },
478            };
479            let _ = tx.send(msg).await;
480        }
481        HandlerKind::Stream(handler) => {
482            let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
483            let sender = StreamSender::new(stream_tx);
484
485            // Forward stream events to gRPC
486            let tx_clone = tx.clone();
487            let rid = request_id.clone();
488            let forward_handle = tokio::spawn(async move {
489                while let Some(data) = stream_rx.recv().await {
490                    let msg = SessionMessage {
491                        payload: Some(Payload::StreamEvent(StreamEvent {
492                            request_id: rid.clone(),
493                            data: data.to_vec(),
494                            event_type: String::new(),
495                        })),
496                    };
497                    if tx_clone.send(msg).await.is_err() {
498                        break;
499                    }
500                }
501            });
502
503            let result = handler(ctx, input, sender).await;
504            forward_handle.await.ok();
505
506            let end_msg = match result {
507                Ok(()) => SessionMessage {
508                    payload: Some(Payload::StreamEnd(StreamEnd {
509                        request_id,
510                        status: crate::proto::Status::Completed.into(),
511                        error: None,
512                    })),
513                },
514                Err(e) => SessionMessage {
515                    payload: Some(Payload::StreamEnd(StreamEnd {
516                        request_id,
517                        status: crate::proto::Status::Failed.into(),
518                        error: Some(ErrorDetail {
519                            code: "EXECUTION_FAILED".into(),
520                            message: e.to_string(),
521                            details: vec![],
522                        }),
523                    })),
524                },
525            };
526            let _ = tx.send(end_msg).await;
527        }
528        HandlerKind::StreamWithFiles(handler) => {
529            let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
530            let sender = StreamSender::new(stream_tx);
531
532            let tx_clone = tx.clone();
533            let rid = request_id.clone();
534            let forward_handle = tokio::spawn(async move {
535                while let Some(data) = stream_rx.recv().await {
536                    let msg = SessionMessage {
537                        payload: Some(Payload::StreamEvent(StreamEvent {
538                            request_id: rid.clone(),
539                            data: data.to_vec(),
540                            event_type: String::new(),
541                        })),
542                    };
543                    if tx_clone.send(msg).await.is_err() {
544                        break;
545                    }
546                }
547            });
548
549            let result = handler(ctx, input, files, sender).await;
550            forward_handle.await.ok();
551
552            let end_msg = match result {
553                Ok(()) => SessionMessage {
554                    payload: Some(Payload::StreamEnd(StreamEnd {
555                        request_id,
556                        status: crate::proto::Status::Completed.into(),
557                        error: None,
558                    })),
559                },
560                Err(e) => SessionMessage {
561                    payload: Some(Payload::StreamEnd(StreamEnd {
562                        request_id,
563                        status: crate::proto::Status::Failed.into(),
564                        error: Some(ErrorDetail {
565                            code: "EXECUTION_FAILED".into(),
566                            message: e.to_string(),
567                            details: vec![],
568                        }),
569                    })),
570                },
571            };
572            let _ = tx.send(end_msg).await;
573        }
574    }
575}