1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use bytes::Bytes;
8use tokio::sync::mpsc;
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tonic::transport::Channel;
12
13use crate::proto::rune_service_client::RuneServiceClient;
14use crate::proto::{
15 session_message::Payload, CasterAttach, ErrorDetail, ExecuteResult,
16 GateConfig as ProtoGateConfig, Heartbeat, RuneDeclaration, SessionMessage, StreamEnd,
17 StreamEvent,
18};
19
20use crate::config::{CasterConfig, FileAttachment, RuneConfig};
21use crate::context::RuneContext;
22use crate::error::{SdkError, SdkResult};
23use crate::handler::{BoxFuture, HandlerKind, RegisteredRune};
24use crate::stream::StreamSender;
25
/// Holds the registered rune handlers and, through `Caster::run`, keeps a
/// session with the runtime open and executes the requests it receives.
pub struct Caster {
    /// Settings passed to `Caster::new` (runtime address, heartbeat and
    /// reconnect timings, labels, key, ...).
    config: CasterConfig,
    /// Id sent in the attach message: `config.caster_id`, or a generated UUID v4.
    caster_id: String,
    /// Registered runes keyed by rune name, shared behind a lock.
    runes: Arc<RwLock<HashMap<String, RegisteredRune>>>,
    /// Cancelled by `Caster::stop`; observed by `run` and the session loop.
    shutdown_token: CancellationToken,
}
33
34impl Caster {
35 pub fn new(config: CasterConfig) -> Self {
37 let caster_id = config
38 .caster_id
39 .clone()
40 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
41 Self {
42 config,
43 caster_id,
44 runes: Arc::new(RwLock::new(HashMap::new())),
45 shutdown_token: CancellationToken::new(),
46 }
47 }
48
49 pub fn caster_id(&self) -> &str {
51 &self.caster_id
52 }
53
54 pub fn config(&self) -> &CasterConfig {
56 &self.config
57 }
58
59 pub fn rune_count(&self) -> usize {
61 self.runes.read().unwrap().len()
62 }
63
64 pub fn get_rune_config(&self, name: &str) -> Option<RuneConfig> {
66 self.runes.read().unwrap().get(name).map(|r| r.config.clone())
67 }
68
69 pub fn is_stream_rune(&self, name: &str) -> bool {
71 self.runes
72 .read()
73 .unwrap()
74 .get(name)
75 .map(|r| r.handler.is_stream())
76 .unwrap_or(false)
77 }
78
79 pub fn stop(&self) {
85 self.shutdown_token.cancel();
86 }
87
88 pub fn rune_accepts_files(&self, name: &str) -> bool {
90 self.runes
91 .read()
92 .unwrap()
93 .get(name)
94 .map(|r| r.handler.accepts_files())
95 .unwrap_or(false)
96 }
97
98 pub fn rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
109 where
110 F: Fn(RuneContext, Bytes) -> Fut + Send + Sync + 'static,
111 Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
112 {
113 let handler = Arc::new(move |ctx, input| -> BoxFuture<'static, SdkResult<Bytes>> {
114 Box::pin(handler(ctx, input))
115 });
116 self.register_inner(config, HandlerKind::Unary(handler))
117 }
118
119 pub fn rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
121 where
122 F: Fn(RuneContext, Bytes, Vec<FileAttachment>) -> Fut + Send + Sync + 'static,
123 Fut: std::future::Future<Output = SdkResult<Bytes>> + Send + 'static,
124 {
125 let handler =
126 Arc::new(
127 move |ctx, input, files| -> BoxFuture<'static, SdkResult<Bytes>> {
128 Box::pin(handler(ctx, input, files))
129 },
130 );
131 self.register_inner(config, HandlerKind::UnaryWithFiles(handler))
132 }
133
134 pub fn stream_rune<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
138 where
139 F: Fn(RuneContext, Bytes, StreamSender) -> Fut + Send + Sync + 'static,
140 Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
141 {
142 let handler =
143 Arc::new(move |ctx, input, stream| -> BoxFuture<'static, SdkResult<()>> {
144 Box::pin(handler(ctx, input, stream))
145 });
146 let mut cfg = config;
147 cfg.supports_stream = true;
148 self.register_inner(cfg, HandlerKind::Stream(handler))
149 }
150
151 pub fn stream_rune_with_files<F, Fut>(&self, config: RuneConfig, handler: F) -> SdkResult<()>
153 where
154 F: Fn(RuneContext, Bytes, Vec<FileAttachment>, StreamSender) -> Fut
155 + Send
156 + Sync
157 + 'static,
158 Fut: std::future::Future<Output = SdkResult<()>> + Send + 'static,
159 {
160 let handler =
161 Arc::new(
162 move |ctx, input, files, stream| -> BoxFuture<'static, SdkResult<()>> {
163 Box::pin(handler(ctx, input, files, stream))
164 },
165 );
166 let mut cfg = config;
167 cfg.supports_stream = true;
168 self.register_inner(cfg, HandlerKind::StreamWithFiles(handler))
169 }
170
171 fn register_inner(&self, config: RuneConfig, handler: HandlerKind) -> SdkResult<()> {
172 let name = config.name.clone();
173 let registered = RegisteredRune { config, handler };
174 let mut runes = self.runes.write().unwrap();
175 if runes.contains_key(&name) {
176 return Err(SdkError::DuplicateRune(name));
177 }
178 runes.insert(name, registered);
179 Ok(())
180 }
181
182 pub async fn run(&self) -> SdkResult<()> {
191 let mut delay = Duration::from_secs_f64(self.config.reconnect_base_delay_secs);
192 let max_delay = Duration::from_secs_f64(self.config.reconnect_max_delay_secs);
193
194 loop {
195 if self.shutdown_token.is_cancelled() {
196 return Ok(());
197 }
198 match self.connect_and_run().await {
199 Ok(()) => return Ok(()),
200 Err(e) => {
201 if self.shutdown_token.is_cancelled() {
202 return Ok(());
203 }
204 tracing::warn!(
205 "connection error: {}, reconnecting in {:?}",
206 e,
207 delay
208 );
209 tokio::select! {
210 _ = tokio::time::sleep(delay) => {}
211 _ = self.shutdown_token.cancelled() => {
212 return Ok(());
213 }
214 }
215 delay = (delay * 2).min(max_delay);
216 }
217 }
218 }
219 }
220
221 async fn connect_and_run(&self) -> SdkResult<()> {
222 let endpoint = format!("http://{}", self.config.runtime);
223 let channel = Channel::from_shared(endpoint)
224 .map_err(|e| SdkError::InvalidUri(e.to_string()))?
225 .connect()
226 .await?;
227 let mut client = RuneServiceClient::new(channel);
228
229 let (tx, rx) = mpsc::channel::<SessionMessage>(32);
231 let outbound = ReceiverStream::new(rx);
232 let response = client.session(outbound).await?;
233 let mut inbound = response.into_inner();
234
235 let attach_msg = self.build_attach_message();
237 tx.send(attach_msg)
238 .await
239 .map_err(|e| SdkError::ChannelSend(e.to_string()))?;
240
241 let hb_tx = tx.clone();
243 let hb_interval = Duration::from_secs_f64(self.config.heartbeat_interval_secs);
244 let hb_handle = tokio::spawn(async move {
245 loop {
246 tokio::time::sleep(hb_interval).await;
247 let msg = SessionMessage {
248 payload: Some(Payload::Heartbeat(Heartbeat {
249 timestamp_ms: std::time::SystemTime::now()
250 .duration_since(std::time::UNIX_EPOCH)
251 .unwrap_or_default()
252 .as_millis() as u64,
253 })),
254 };
255 if hb_tx.send(msg).await.is_err() {
256 break;
257 }
258 }
259 });
260
261 let cancel_tokens: Arc<tokio::sync::RwLock<HashMap<String, CancellationToken>>> =
263 Arc::new(tokio::sync::RwLock::new(HashMap::new()));
264
265 loop {
267 let msg = tokio::select! {
268 msg = inbound.message() => {
269 match msg? {
270 Some(m) => m,
271 None => break, }
273 }
274 _ = self.shutdown_token.cancelled() => {
275 break;
276 }
277 };
278 match msg.payload {
279 Some(Payload::AttachAck(ack)) => {
280 if ack.accepted {
281 tracing::info!(
282 "attached to {}, caster_id={}",
283 self.config.runtime,
284 self.caster_id
285 );
286 } else {
287 tracing::error!("attach rejected: {}", ack.reason);
288 return Err(SdkError::Other(format!(
289 "attach rejected: {}",
290 ack.reason
291 )));
292 }
293 }
294 Some(Payload::Execute(req)) => {
295 let registered = self.runes.read().unwrap().get(&req.rune_name).cloned();
296
297 let token = CancellationToken::new();
298 cancel_tokens
299 .write()
300 .await
301 .insert(req.request_id.clone(), token.clone());
302
303 let tx_clone = tx.clone();
304 let cancel_tokens_clone = cancel_tokens.clone();
305 let request_id = req.request_id.clone();
306 tokio::spawn(async move {
307 execute_handler(registered, req, tx_clone, token).await;
308 cancel_tokens_clone.write().await.remove(&request_id);
309 });
310 }
311 Some(Payload::Cancel(cancel)) => {
312 if let Some(token) =
313 cancel_tokens.read().await.get(&cancel.request_id)
314 {
315 token.cancel();
316 }
317 tracing::info!("cancel requested: {}", cancel.request_id);
318 }
319 Some(Payload::Heartbeat(_)) => {
320 }
322 _ => {}
323 }
324 }
325
326 hb_handle.abort();
327 Ok(())
328 }
329
330 fn build_attach_message(&self) -> SessionMessage {
331 let runes = self.runes.read().unwrap();
332 let mut declarations = Vec::new();
333
334 for registered in runes.values() {
335 let cfg = ®istered.config;
336 let gate = cfg.gate.as_ref().map(|g| ProtoGateConfig {
337 path: g.path.clone(),
338 method: g.method.clone(),
339 });
340 let input_schema = cfg
341 .input_schema
342 .as_ref()
343 .map(|s| serde_json::to_string(s).unwrap_or_default())
344 .unwrap_or_default();
345 let output_schema = cfg
346 .output_schema
347 .as_ref()
348 .map(|s| serde_json::to_string(s).unwrap_or_default())
349 .unwrap_or_default();
350
351 declarations.push(RuneDeclaration {
352 name: cfg.name.clone(),
353 version: cfg.version.clone(),
354 description: cfg.description.clone(),
355 input_schema,
356 output_schema,
357 supports_stream: cfg.supports_stream,
358 gate,
359 priority: cfg.priority,
360 });
361 }
362
363 SessionMessage {
364 payload: Some(Payload::Attach(CasterAttach {
365 caster_id: self.caster_id.clone(),
366 runes: declarations,
367 labels: self.config.labels.clone(),
368 max_concurrent: self.config.max_concurrent,
369 key: self.config.key.clone().unwrap_or_default(),
370 })),
371 }
372 }
373}
374
375async fn execute_handler(
380 registered: Option<RegisteredRune>,
381 req: crate::proto::ExecuteRequest,
382 tx: mpsc::Sender<SessionMessage>,
383 cancel_token: CancellationToken,
384) {
385 let request_id = req.request_id.clone();
386
387 let Some(registered) = registered else {
388 let _ = tx
389 .send(SessionMessage {
390 payload: Some(Payload::Result(ExecuteResult {
391 request_id,
392 status: crate::proto::Status::Failed.into(),
393 output: vec![],
394 error: Some(ErrorDetail {
395 code: "NOT_FOUND".into(),
396 message: format!("rune '{}' not found", req.rune_name),
397 details: vec![],
398 }),
399 attachments: vec![],
400 })),
401 })
402 .await;
403 return;
404 };
405
406 let ctx = RuneContext {
407 rune_name: req.rune_name.clone(),
408 request_id: request_id.clone(),
409 context: req.context.clone(),
410 cancellation: cancel_token,
411 };
412
413 let input = Bytes::from(req.input);
414 let files: Vec<FileAttachment> = req
415 .attachments
416 .iter()
417 .map(|a| FileAttachment {
418 filename: a.filename.clone(),
419 data: Bytes::from(a.data.clone()),
420 mime_type: a.mime_type.clone(),
421 })
422 .collect();
423
424 match ®istered.handler {
425 HandlerKind::Unary(handler) => {
426 let result = handler(ctx, input).await;
427 let msg = match result {
428 Ok(output) => SessionMessage {
429 payload: Some(Payload::Result(ExecuteResult {
430 request_id,
431 status: crate::proto::Status::Completed.into(),
432 output: output.to_vec(),
433 error: None,
434 attachments: vec![],
435 })),
436 },
437 Err(e) => SessionMessage {
438 payload: Some(Payload::Result(ExecuteResult {
439 request_id,
440 status: crate::proto::Status::Failed.into(),
441 output: vec![],
442 error: Some(ErrorDetail {
443 code: "EXECUTION_FAILED".into(),
444 message: e.to_string(),
445 details: vec![],
446 }),
447 attachments: vec![],
448 })),
449 },
450 };
451 let _ = tx.send(msg).await;
452 }
453 HandlerKind::UnaryWithFiles(handler) => {
454 let result = handler(ctx, input, files).await;
455 let msg = match result {
456 Ok(output) => SessionMessage {
457 payload: Some(Payload::Result(ExecuteResult {
458 request_id,
459 status: crate::proto::Status::Completed.into(),
460 output: output.to_vec(),
461 error: None,
462 attachments: vec![],
463 })),
464 },
465 Err(e) => SessionMessage {
466 payload: Some(Payload::Result(ExecuteResult {
467 request_id,
468 status: crate::proto::Status::Failed.into(),
469 output: vec![],
470 error: Some(ErrorDetail {
471 code: "EXECUTION_FAILED".into(),
472 message: e.to_string(),
473 details: vec![],
474 }),
475 attachments: vec![],
476 })),
477 },
478 };
479 let _ = tx.send(msg).await;
480 }
481 HandlerKind::Stream(handler) => {
482 let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
483 let sender = StreamSender::new(stream_tx);
484
485 let tx_clone = tx.clone();
487 let rid = request_id.clone();
488 let forward_handle = tokio::spawn(async move {
489 while let Some(data) = stream_rx.recv().await {
490 let msg = SessionMessage {
491 payload: Some(Payload::StreamEvent(StreamEvent {
492 request_id: rid.clone(),
493 data: data.to_vec(),
494 event_type: String::new(),
495 })),
496 };
497 if tx_clone.send(msg).await.is_err() {
498 break;
499 }
500 }
501 });
502
503 let result = handler(ctx, input, sender).await;
504 forward_handle.await.ok();
505
506 let end_msg = match result {
507 Ok(()) => SessionMessage {
508 payload: Some(Payload::StreamEnd(StreamEnd {
509 request_id,
510 status: crate::proto::Status::Completed.into(),
511 error: None,
512 })),
513 },
514 Err(e) => SessionMessage {
515 payload: Some(Payload::StreamEnd(StreamEnd {
516 request_id,
517 status: crate::proto::Status::Failed.into(),
518 error: Some(ErrorDetail {
519 code: "EXECUTION_FAILED".into(),
520 message: e.to_string(),
521 details: vec![],
522 }),
523 })),
524 },
525 };
526 let _ = tx.send(end_msg).await;
527 }
528 HandlerKind::StreamWithFiles(handler) => {
529 let (stream_tx, mut stream_rx) = mpsc::channel::<Bytes>(32);
530 let sender = StreamSender::new(stream_tx);
531
532 let tx_clone = tx.clone();
533 let rid = request_id.clone();
534 let forward_handle = tokio::spawn(async move {
535 while let Some(data) = stream_rx.recv().await {
536 let msg = SessionMessage {
537 payload: Some(Payload::StreamEvent(StreamEvent {
538 request_id: rid.clone(),
539 data: data.to_vec(),
540 event_type: String::new(),
541 })),
542 };
543 if tx_clone.send(msg).await.is_err() {
544 break;
545 }
546 }
547 });
548
549 let result = handler(ctx, input, files, sender).await;
550 forward_handle.await.ok();
551
552 let end_msg = match result {
553 Ok(()) => SessionMessage {
554 payload: Some(Payload::StreamEnd(StreamEnd {
555 request_id,
556 status: crate::proto::Status::Completed.into(),
557 error: None,
558 })),
559 },
560 Err(e) => SessionMessage {
561 payload: Some(Payload::StreamEnd(StreamEnd {
562 request_id,
563 status: crate::proto::Status::Failed.into(),
564 error: Some(ErrorDetail {
565 code: "EXECUTION_FAILED".into(),
566 message: e.to_string(),
567 details: vec![],
568 }),
569 })),
570 },
571 };
572 let _ = tx.send(end_msg).await;
573 }
574 }
575}