1use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5use std::sync::{Arc, RwLock};
6use std::time::Duration;
7
8use bytes::Bytes;
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tonic::transport::Channel;
13
14use rune_proto::rune_service_client::RuneServiceClient;
15use rune_proto::{
16 session_message::Payload, CasterAttach, ErrorDetail, ExecuteResult,
17 GateConfig as ProtoGateConfig, HealthReport, HealthStatus, Heartbeat, RuneDeclaration,
18 SessionMessage, StreamEnd, StreamEvent,
19};
20
21use crate::config::{CasterConfig, FileAttachment, RuneConfig};
22use crate::context::RuneContext;
23use crate::error::{SdkError, SdkResult};
24use crate::handler::{BoxFuture, HandlerKind, RegisteredRune};
25use crate::pilot_client::PilotClient;
26use crate::stream::StreamSender;
27
28pub struct Caster {
30 config: CasterConfig,
31 caster_id: String,
32 runes: Arc<RwLock<HashMap<String, RegisteredRune>>>,
33 shutdown_token: CancellationToken,
34 active_requests: Arc<AtomicU32>,
35 draining: Arc<AtomicBool>,
37 drain_notify: Arc<tokio::sync::Notify>,
40}
41
42impl Caster {
43 pub fn new(config: CasterConfig) -> Self {
45 let caster_id = config
46 .caster_id
47 .clone()
48 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
49 Self {
50 config,
51 caster_id,
52 runes: Arc::new(RwLock::new(HashMap::new())),
53 shutdown_token: CancellationToken::new(),
54 active_requests: Arc::new(AtomicU32::new(0)),
55 draining: Arc::new(AtomicBool::new(false)),
56 drain_notify: Arc::new(tokio::sync::Notify::new()),
57 }
58 }
59
60 pub fn caster_id(&self) -> &str {
62 &self.caster_id
63 }
64
65 pub fn config(&self) -> &CasterConfig {
67 &self.config
68 }
69
70 pub fn rune_count(&self) -> usize {
72 self.runes.read().unwrap().len()
73 }
74
75 pub fn get_rune_config(&self, name: &str) -> Option<RuneConfig> {
77 self.runes
78 .read()
79 .unwrap()
80 .get(name)
81 .map(|r| r.config.clone())
82 }
83
84 pub fn is_stream_rune(&self, name: &str) -> bool {
86 self.runes
87 .read()
88 .unwrap()
89 .get(name)
90 .map(|r| r.handler.is_stream())
91 .unwrap_or(false)
92 }
93
94 pub fn stop(&self) {
100 self.shutdown_token.cancel();
101 }
102
103 pub fn rune_accepts_files(&self, name: &str) -> bool {
105 self.runes
106 .read()
107 .unwrap()
108 .get(name)
109 .map(|r| r.handler.accepts_files())
110 .unwrap_or(false)
111 }
112
113 pub fn rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
124 where
125 F: Fn(RuneContext, Bytes) -> Fut + Send + Sync + 'static,
126 Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
127 {
128 let handler = Arc::new(move |ctx, input| -> BoxFuture<'static, SdkResult<Bytes>> {
129 Box::pin(handler(ctx, input))
130 });
131 self.register_inner(config, HandlerKind::Unary(handler))
132 }
133
134 pub fn rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
136 where
137 F: Fn(RuneContext, Bytes, Vec<FileAttachment>) -> Fut + Send + Sync + 'static,
138 Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
139 {
140 let handler = Arc::new(
141 move |ctx, input, files| -> BoxFuture<'static, SdkResult<Bytes>> {
142 Box::pin(handler(ctx, input, files))
143 },
144 );
145 self.register_inner(config, HandlerKind::UnaryWithFiles(handler))
146 }
147
148 pub fn stream_rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
152 where
153 F: Fn(RuneContext, Bytes, StreamSender) -> Fut + Send + Sync + 'static,
154 Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
155 {
156 let handler = Arc::new(
157 move |ctx, input, stream| -> BoxFuture<'static, SdkResult<()>> {
158 Box::pin(handler(ctx, input, stream))
159 },
160 );
161 let mut cfg = config;
162 cfg.supports_stream = true;
163 self.register_inner(cfg, HandlerKind::Stream(handler))
164 }
165
166 pub fn stream_rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
168 where
169 F: Fn(RuneContext, Bytes, Vec<FileAttachment>, StreamSender) -> Fut + Send + Sync + 'static,
170 Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
171 {
172 let handler = Arc::new(
173 move |ctx, input, files, stream| -> BoxFuture<'static, SdkResult<()>> {
174 Box::pin(handler(ctx, input, files, stream))
175 },
176 );
177 let mut cfg = config;
178 cfg.supports_stream = true;
179 self.register_inner(cfg, HandlerKind::StreamWithFiles(handler))
180 }
181
182 fn register_inner(&self, config: RuneConfig, handler: HandlerKind) -> SdkResult<()> {
183 let name = config.name.clone();
184 let registered = RegisteredRune { config, handler };
185 let mut runes = self.runes.write().unwrap();
186 if runes.contains_key(&name) {
187 return Err(SdkError::DuplicateRune(name));
188 }
189 runes.insert(name, registered);
190 Ok(())
191 }
192
193 pub async fn run(&self) -> SdkResult<()> {
202 let mut delay = Duration::from_secs_f64(self.config.reconnect_base_delay_secs);
203 let max_delay = Duration::from_secs_f64(self.config.reconnect_max_delay_secs);
204 let mut last_pilot: Option<PilotClient> = None;
205
206 let result = loop {
207 if self.shutdown_token.is_cancelled() {
208 break Ok(());
209 }
210
211 let pilot_id = if let Some(policy) = self.config.scale_policy.as_ref() {
214 match PilotClient::ensure(&self.config.runtime, self.config.key.as_deref()).await {
215 Ok(client) => match client.register(&self.caster_id, policy).await {
216 Ok(()) => {
217 let id = client.pilot_id().to_string();
218 last_pilot = Some(client);
219 Some(id)
220 }
221 Err(e) => {
222 tracing::warn!("pilot registration failed: {e}");
223 last_pilot = Some(client);
224 None
225 }
226 },
227 Err(e) => {
228 tracing::warn!("pilot ensure failed: {e}");
229 None
230 }
231 }
232 } else {
233 None
234 };
235
236 match self.connect_and_run(pilot_id.as_deref()).await {
237 Ok(()) => break Ok(()),
238 Err(SdkError::AttachRejected(reason)) => {
239 tracing::error!("attach permanently rejected: {reason}");
240 break Err(SdkError::AttachRejected(reason));
241 }
242 Err(e) => {
243 if self.shutdown_token.is_cancelled() {
244 break Ok(());
245 }
246 tracing::warn!("connection error: {}, reconnecting in {:?}", e, delay);
247 tokio::select! {
248 _ = tokio::time::sleep(delay) => {}
249 _ = self.shutdown_token.cancelled() => {
250 break Ok(());
251 }
252 }
253 delay = (delay * 2).min(max_delay);
254 }
255 }
256 };
257
258 if let Some(client) = last_pilot {
261 let _ = client.deregister(&self.caster_id).await;
262 }
263
264 result
265 }
266
267 async fn connect_and_run(&self, pilot_id: Option<&str>) -> SdkResult<()> {
268 self.draining.store(false, Ordering::Relaxed);
270 let endpoint = format!("http://{}", self.config.runtime);
271 let channel = Channel::from_shared(endpoint)
272 .map_err(|e| SdkError::InvalidUri(e.to_string()))?
273 .connect()
274 .await?;
275 let mut client = RuneServiceClient::new(channel);
276
277 let (tx, rx) = mpsc::channel::<SessionMessage>(32);
279 let outbound = ReceiverStream::new(rx);
280 let response = client.session(outbound).await?;
281 let mut inbound = response.into_inner();
282
283 let attach_msg = self.build_attach_message(pilot_id);
285 tx.send(attach_msg)
286 .await
287 .map_err(|e| SdkError::ChannelSend(e.to_string()))?;
288
289 let hb_tx = tx.clone();
291 let hb_interval = Duration::from_secs_f64(self.config.heartbeat_interval_secs);
292 let config = self.config.clone();
293 let active_requests = Arc::clone(&self.active_requests);
294 let hb_draining = Arc::clone(&self.draining);
295 let hb_handle = tokio::spawn(async move {
296 loop {
297 tokio::time::sleep(hb_interval).await;
298 let msg = SessionMessage {
299 payload: Some(Payload::Heartbeat(Heartbeat {
300 timestamp_ms: std::time::SystemTime::now()
301 .duration_since(std::time::UNIX_EPOCH)
302 .unwrap_or_default()
303 .as_millis() as u64,
304 })),
305 };
306 if hb_tx.send(msg).await.is_err() {
307 break;
308 }
309 let active = active_requests.load(Ordering::Relaxed);
311 let is_draining = hb_draining.load(Ordering::Relaxed);
312 if hb_tx
313 .send(build_health_report_message(&config, active, is_draining))
314 .await
315 .is_err()
316 {
317 break;
318 }
319 }
320 });
321
322 let cancel_tokens: Arc<tokio::sync::RwLock<HashMap<String, CancellationToken>>> =
324 Arc::new(tokio::sync::RwLock::new(HashMap::new()));
325
326 loop {
328 let msg = tokio::select! {
329 msg = inbound.message() => {
330 match msg? {
331 Some(m) => m,
332 None => break, }
334 }
335 _ = self.shutdown_token.cancelled() => {
336 break;
337 }
338 };
339 match msg.payload {
340 Some(Payload::AttachAck(ack)) => {
341 if ack.accepted {
342 tracing::info!(
343 "attached to {}, caster_id={}",
344 self.config.runtime,
345 self.caster_id
346 );
347 tx.send(build_health_report_message(
349 &self.config,
350 self.active_requests.load(Ordering::Relaxed),
351 self.draining.load(Ordering::Relaxed),
352 ))
353 .await
354 .map_err(|e| SdkError::ChannelSend(e.to_string()))?;
355 } else {
356 tracing::error!("attach rejected: {}", ack.reason);
357 return Err(SdkError::AttachRejected(ack.reason.clone()));
358 }
359 }
360 Some(Payload::Execute(req)) => {
361 if self.draining.load(Ordering::Relaxed) {
363 let _ = tx
364 .send(SessionMessage {
365 payload: Some(Payload::Result(ExecuteResult {
366 request_id: req.request_id,
367 status: rune_proto::Status::Failed.into(),
368 output: vec![],
369 error: Some(ErrorDetail {
370 code: "SHUTTING_DOWN".into(),
371 message: "caster is draining, no new requests accepted"
372 .into(),
373 details: vec![],
374 }),
375 attachments: vec![],
376 })),
377 })
378 .await;
379 continue;
380 }
381 let registered = self.runes.read().unwrap().get(&req.rune_name).cloned();
382 self.active_requests.fetch_add(1, Ordering::Relaxed);
383
384 let token = CancellationToken::new();
385 cancel_tokens
386 .write()
387 .await
388 .insert(req.request_id.clone(), token.clone());
389
390 let tx_clone = tx.clone();
391 let cancel_tokens_clone = cancel_tokens.clone();
392 let request_id = req.request_id.clone();
393 let active_requests = Arc::clone(&self.active_requests);
394 let drain_notify = Arc::clone(&self.drain_notify);
395 tokio::spawn(async move {
396 struct Guard(Arc<std::sync::atomic::AtomicU32>, Arc<tokio::sync::Notify>);
400 impl Drop for Guard {
401 fn drop(&mut self) {
402 if self.0.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) == 1 {
403 self.1.notify_one();
404 }
405 }
406 }
407 let _guard = Guard(active_requests, drain_notify);
408
409 execute_handler(registered, req, tx_clone, token).await;
410 cancel_tokens_clone.write().await.remove(&request_id);
411 });
412 }
413 Some(Payload::Cancel(cancel)) => {
414 if let Some(token) = cancel_tokens.read().await.get(&cancel.request_id) {
415 token.cancel();
416 }
417 tracing::info!("cancel requested: {}", cancel.request_id);
418 }
419 Some(Payload::Heartbeat(_)) => {
420 }
422 Some(Payload::Shutdown(shutdown)) => {
423 tracing::info!(
424 "shutdown requested: {}, grace_period_ms={}",
425 shutdown.reason,
426 shutdown.grace_period_ms
427 );
428 self.draining.store(true, Ordering::Relaxed);
430 let _ = tx
433 .send(build_health_report_message(
434 &self.config,
435 self.active_requests.load(Ordering::Relaxed),
436 true,
437 ))
438 .await;
439 let grace = Duration::from_millis(shutdown.grace_period_ms as u64);
442 let drain_deadline = tokio::time::Instant::now() + grace;
443 while self.active_requests.load(Ordering::Relaxed) > 0 {
444 let remaining =
445 drain_deadline.saturating_duration_since(tokio::time::Instant::now());
446 if remaining.is_zero() {
447 tracing::warn!(
448 "grace period expired with {} active requests remaining",
449 self.active_requests.load(Ordering::Relaxed)
450 );
451 break;
452 }
453 tokio::select! {
456 _ = self.drain_notify.notified() => {}
457 _ = tokio::time::sleep(remaining) => {}
458 }
459 }
460 self.stop();
461 break;
462 }
463 _ => {}
464 }
465 }
466
467 hb_handle.abort();
468 Ok(())
469 }
470
471 fn build_attach_message(&self, pilot_id: Option<&str>) -> SessionMessage {
472 let runes = self.runes.read().unwrap();
473 let mut declarations = Vec::new();
474
475 for registered in runes.values() {
476 let cfg = ®istered.config;
477 let gate = cfg.gate.as_ref().map(|g| ProtoGateConfig {
478 path: g.path.clone(),
479 method: g.method.clone(),
480 });
481 let input_schema = cfg
482 .input_schema
483 .as_ref()
484 .map(|s| serde_json::to_string(s).unwrap_or_default())
485 .unwrap_or_default();
486 let output_schema = cfg
487 .output_schema
488 .as_ref()
489 .map(|s| serde_json::to_string(s).unwrap_or_default())
490 .unwrap_or_default();
491
492 declarations.push(RuneDeclaration {
493 name: cfg.name.clone(),
494 version: cfg.version.clone(),
495 description: cfg.description.clone(),
496 input_schema,
497 output_schema,
498 supports_stream: cfg.supports_stream,
499 gate,
500 priority: cfg.priority,
501 });
502 }
503
504 SessionMessage {
505 payload: Some(Payload::Attach(CasterAttach {
506 caster_id: self.caster_id.clone(),
507 runes: declarations,
508 labels: self.attach_labels(pilot_id),
509 max_concurrent: self.config.max_concurrent,
510 key: self.config.key.clone().unwrap_or_default(),
511 role: "caster".into(),
512 })),
513 }
514 }
515
516 fn attach_labels(&self, pilot_id: Option<&str>) -> HashMap<String, String> {
517 let mut labels = self.config.labels.clone();
518 if let Some(policy) = self.config.scale_policy.as_ref() {
519 labels.insert("group".into(), policy.group.clone());
520 labels.insert("_scale_up".into(), policy.scale_up_threshold.to_string());
521 labels.insert(
522 "_scale_down".into(),
523 policy.scale_down_threshold.to_string(),
524 );
525 labels.insert("_sustained".into(), policy.sustained_secs.to_string());
526 labels.insert("_min".into(), policy.min_replicas.to_string());
527 labels.insert("_max".into(), policy.max_replicas.to_string());
528 labels.insert("_spawn_command".into(), policy.spawn_command.clone());
529 labels.insert("_shutdown_signal".into(), policy.shutdown_signal.clone());
530 if let Some(pilot_id) = pilot_id {
531 labels.insert("_pilot_id".into(), pilot_id.to_string());
532 }
533 }
534 labels
535 }
536}
537
538fn build_health_report_message(
539 config: &CasterConfig,
540 active_requests: u32,
541 draining: bool,
542) -> SessionMessage {
543 let mut metrics = config
544 .load_report
545 .as_ref()
546 .map(|report| report.metrics.clone())
547 .unwrap_or_default();
548 metrics
549 .entry("active_requests".into())
550 .or_insert(active_requests as f64);
551 metrics
552 .entry("max_concurrent".into())
553 .or_insert(config.max_concurrent as f64);
554 metrics
555 .entry("available_permits".into())
556 .or_insert(config.max_concurrent.saturating_sub(active_requests) as f64);
557
558 let computed_pressure = if config.max_concurrent == 0 {
559 0.0
560 } else {
561 active_requests as f64 / config.max_concurrent as f64
562 };
563 let pressure = config
564 .load_report
565 .as_ref()
566 .and_then(|lr| lr.pressure)
567 .unwrap_or(computed_pressure);
568
569 let status = if draining {
570 HealthStatus::Unhealthy
571 } else {
572 HealthStatus::Healthy
573 };
574 SessionMessage {
575 payload: Some(Payload::HealthReport(HealthReport {
576 status: status.into(),
577 active_requests,
578 error_rate: 0.0,
579 custom_info: String::new(),
580 timestamp_ms: std::time::SystemTime::now()
581 .duration_since(std::time::UNIX_EPOCH)
582 .unwrap_or_default()
583 .as_millis() as u64,
584 error_rate_window_secs: 0,
585 pressure,
586 metrics,
587 })),
588 }
589}
590
591async fn execute_handler(
596 registered: Option<RegisteredRune>,
597 req: rune_proto::ExecuteRequest,
598 tx: mpsc::Sender<SessionMessage>,
599 cancel_token: CancellationToken,
600) {
601 let request_id = req.request_id.clone();
602
603 let Some(registered) = registered else {
604 let _ = tx
605 .send(SessionMessage {
606 payload: Some(Payload::Result(ExecuteResult {
607 request_id,
608 status: rune_proto::Status::Failed.into(),
609 output: vec![],
610 error: Some(ErrorDetail {
611 code: "NOT_FOUND".into(),
612 message: format!("rune '{}' not found", req.rune_name),
613 details: vec![],
614 }),
615 attachments: vec![],
616 })),
617 })
618 .await;
619 return;
620 };
621
622 let ctx = RuneContext {
623 rune_name: req.rune_name.clone(),
624 request_id: request_id.clone(),
625 context: req.context.clone(),
626 cancellation: cancel_token,
627 };
628
629 let input = Bytes::from(req.input);
630 let files: Vec<FileAttachment> = req
631 .attachments
632 .iter()
633 .map(|a| FileAttachment {
634 filename: a.filename.clone(),
635 data: Bytes::from(a.data.clone()),
636 mime_type: a.mime_type.clone(),
637 })
638 .collect();
639
640 match ®istered.handler {
641 HandlerKind::Unary(handler) => {
642 let result = handler(ctx, input).await;
643 let msg = match result {
644 Ok(output) => SessionMessage {
645 payload: Some(Payload::Result(ExecuteResult {
646 request_id,
647 status: rune_proto::Status::Completed.into(),
648 output: output.to_vec(),
649 error: None,
650 attachments: vec![],
651 })),
652 },
653 Err(e) => SessionMessage {
654 payload: Some(Payload::Result(ExecuteResult {
655 request_id,
656 status: rune_proto::Status::Failed.into(),
657 output: vec![],
658 error: Some(ErrorDetail {
659 code: "EXECUTION_FAILED".into(),
660 message: e.to_string(),
661 details: vec![],
662 }),
663 attachments: vec![],
664 })),
665 },
666 };
667 let _ = tx.send(msg).await;
668 }
669 HandlerKind::UnaryWithFiles(handler) => {
670 let result = handler(ctx, input, files).await;
671 let msg = match result {
672 Ok(output) => SessionMessage {
673 payload: Some(Payload::Result(ExecuteResult {
674 request_id,
675 status: rune_proto::Status::Completed.into(),
676 output: output.to_vec(),
677 error: None,
678 attachments: vec![],
679 })),
680 },
681 Err(e) => SessionMessage {
682 payload: Some(Payload::Result(ExecuteResult {
683 request_id,
684 status: rune_proto::Status::Failed.into(),
685 output: vec![],
686 error: Some(ErrorDetail {
687 code: "EXECUTION_FAILED".into(),
688 message: e.to_string(),
689 details: vec![],
690 }),
691 attachments: vec![],
692 })),
693 },
694 };
695 let _ = tx.send(msg).await;
696 }
697 HandlerKind::Stream(handler) => {
698 let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
699 let sender = StreamSender::new(stream_tx);
700
701 let tx_clone = tx.clone();
703 let rid = request_id.clone();
704 let forward_handle = tokio::spawn(async move {
705 while let Some(data) = stream_rx.recv().await {
706 let msg = SessionMessage {
707 payload: Some(Payload::StreamEvent(StreamEvent {
708 request_id: rid.clone(),
709 data: data.to_vec(),
710 event_type: String::new(),
711 })),
712 };
713 if tx_clone.send(msg).await.is_err() {
714 break;
715 }
716 }
717 });
718
719 let result = handler(ctx, input, sender).await;
720 forward_handle.await.ok();
721
722 let end_msg = match result {
723 Ok(()) => SessionMessage {
724 payload: Some(Payload::StreamEnd(StreamEnd {
725 request_id,
726 status: rune_proto::Status::Completed.into(),
727 error: None,
728 })),
729 },
730 Err(e) => SessionMessage {
731 payload: Some(Payload::StreamEnd(StreamEnd {
732 request_id,
733 status: rune_proto::Status::Failed.into(),
734 error: Some(ErrorDetail {
735 code: "EXECUTION_FAILED".into(),
736 message: e.to_string(),
737 details: vec![],
738 }),
739 })),
740 },
741 };
742 let _ = tx.send(end_msg).await;
743 }
744 HandlerKind::StreamWithFiles(handler) => {
745 let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
746 let sender = StreamSender::new(stream_tx);
747
748 let tx_clone = tx.clone();
749 let rid = request_id.clone();
750 let forward_handle = tokio::spawn(async move {
751 while let Some(data) = stream_rx.recv().await {
752 let msg = SessionMessage {
753 payload: Some(Payload::StreamEvent(StreamEvent {
754 request_id: rid.clone(),
755 data: data.to_vec(),
756 event_type: String::new(),
757 })),
758 };
759 if tx_clone.send(msg).await.is_err() {
760 break;
761 }
762 }
763 });
764
765 let result = handler(ctx, input, files, sender).await;
766 forward_handle.await.ok();
767
768 let end_msg = match result {
769 Ok(()) => SessionMessage {
770 payload: Some(Payload::StreamEnd(StreamEnd {
771 request_id,
772 status: rune_proto::Status::Completed.into(),
773 error: None,
774 })),
775 },
776 Err(e) => SessionMessage {
777 payload: Some(Payload::StreamEnd(StreamEnd {
778 request_id,
779 status: rune_proto::Status::Failed.into(),
780 error: Some(ErrorDetail {
781 code: "EXECUTION_FAILED".into(),
782 message: e.to_string(),
783 details: vec![],
784 }),
785 })),
786 },
787 };
788 let _ = tx.send(end_msg).await;
789 }
790 }
791}