1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3use std::sync::Arc;
4use std::time::Duration;
5
6use crate::client::{AkribesClient, Inner};
7use crate::error::{AkribesError, Result};
8use crate::models::*;
9
10fn validate_contract(
12 inner: &Inner,
13 script_name: &str,
14 document_keys: Option<&[&str]>,
15) -> Result<()> {
16 if inner.broken_scripts.lock().unwrap().contains(script_name) {
17 return Err(AkribesError::ScriptSchemaChanged {
18 script_name: script_name.to_string(),
19 });
20 }
21
22 if let Some(doc_keys) = document_keys {
23 let schemas = inner.schema_cache.lock().unwrap();
24 if let Some(schema) = schemas.get(script_name) {
25 let expected_docs: Vec<&str> = schema
26 .iter()
27 .filter(|(_, ty)| ty == "document")
28 .map(|(name, _)| name.as_str())
29 .collect();
30 let provided: HashSet<&str> = doc_keys.iter().copied().collect();
31 let missing: Vec<String> = expected_docs
32 .iter()
33 .filter(|n| !provided.contains(**n))
34 .map(|n| n.to_string())
35 .collect();
36 let expected: HashSet<&str> = expected_docs.into_iter().collect();
37 let extra: Vec<String> = doc_keys
38 .iter()
39 .filter(|k| !expected.contains(**k))
40 .map(|k| k.to_string())
41 .collect();
42 if !missing.is_empty() || !extra.is_empty() {
43 return Err(AkribesError::ScriptInputMismatch {
44 script_name: script_name.to_string(),
45 missing,
46 extra,
47 });
48 }
49 }
50 }
51
52 Ok(())
53}
54
/// Client for the endpoints that are addressed by an execution id or a
/// document id (`/executions/...` and `/documents/...` under the base URL).
///
/// Cloning is cheap: a clone shares the same [`Inner`] through the `Arc`.
#[derive(Clone, Debug)]
pub struct ExecutionsClient {
    /// Shared client state; the methods in this file read `base_url`, `http`
    /// and `name` from it.
    pub(crate) inner: Arc<Inner>,
}
67
68impl ExecutionsClient {
69 pub(crate) fn new(inner: Arc<Inner>) -> Self {
70 Self { inner }
71 }
72
73 fn c(&self) -> AkribesClient {
74 AkribesClient {
75 inner: Arc::clone(&self.inner),
76 }
77 }
78
79 pub async fn resume(
81 &self,
82 execution_id: &str,
83 token: &str,
84 data: serde_json::Value,
85 ) -> Result<serde_json::Value> {
86 let url = format!("{}/executions/{}/resume", self.inner.base_url, execution_id);
87 self.c()
88 .post(
89 &url,
90 &ResumeRequest {
91 token: token.to_string(),
92 data,
93 },
94 )
95 .await
96 }
97
98 pub async fn cancel(&self, execution_id: &str) -> Result<bool> {
100 let url = format!("{}/executions/{}", self.inner.base_url, execution_id);
101 self.c().delete(&url).await
102 }
103
104 pub async fn children(&self, execution_id: &str) -> Result<Vec<ExecutionChildSummary>> {
110 let url = format!(
111 "{}/executions/{}/children",
112 self.inner.base_url, execution_id
113 );
114 self.c().get_list(&url).await
115 }
116
117 pub async fn get(&self, execution_id: &str) -> Result<Option<ExecutionStatus>> {
119 let url = format!("{}/executions/{}", self.inner.base_url, execution_id);
120 self.c().get_opt(&url).await
121 }
122
123 pub async fn tasks(&self, execution_id: &str) -> Result<Option<ExecutionTasksResponse>> {
134 let url = format!(
135 "{}/executions/{}/tasks",
136 self.inner.base_url,
137 urlencoding::encode(execution_id)
138 );
139 self.c().get_opt(&url).await
140 }
141
142 pub async fn get_output(&self, execution_id: &str) -> Result<Option<ExecutionOutput>> {
144 let url = format!("{}/executions/{}/output", self.inner.base_url, execution_id);
145 self.c().get_opt(&url).await
146 }
147
148 pub async fn get_events(
152 &self,
153 execution_id: &str,
154 after_id: Option<i64>,
155 limit: Option<i64>,
156 types: Option<&str>,
157 ) -> Result<Option<ExecutionEvents>> {
158 #[derive(serde::Serialize)]
159 struct Q<'a> {
160 #[serde(skip_serializing_if = "Option::is_none")]
161 after_id: Option<i64>,
162 #[serde(skip_serializing_if = "Option::is_none")]
163 limit: Option<i64>,
164 #[serde(skip_serializing_if = "Option::is_none")]
165 types: Option<&'a str>,
166 }
167 let base = format!("{}/executions/{}/events", self.inner.base_url, execution_id);
168 let url = AkribesClient::url_with_query(
169 &base,
170 &Q {
171 after_id,
172 limit,
173 types,
174 },
175 );
176 let res = self.c().send(self.c().inner.http.get(&url)).await?;
177 if res.status() == reqwest::StatusCode::NOT_FOUND {
178 return Ok(None);
179 }
180 Ok(Some(crate::client::decode_json(res).await?))
181 }
182
183 pub async fn get_document(&self, document_id: &str) -> Result<Option<DocumentMeta>> {
187 let url = format!(
188 "{}/documents/{}",
189 self.inner.base_url,
190 urlencoding::encode(document_id)
191 );
192 self.c().get_opt(&url).await
193 }
194
195 pub async fn get_document_markdown(&self, document_id: &str) -> Result<String> {
197 let url = format!(
198 "{}/documents/{}/markdown",
199 self.inner.base_url,
200 urlencoding::encode(document_id)
201 );
202 let res = self.c().send(self.c().inner.http.get(&url)).await?;
203 let body: serde_json::Value = crate::client::decode_json(res).await?;
204 match body.get("markdown") {
208 Some(serde_json::Value::String(s)) => Ok(s.clone()),
209 other => Err(AkribesError::Other(format!(
210 "GET /documents/{}/markdown returned a malformed response: \
211 expected a string `markdown` field, got {}",
212 document_id,
213 match other {
214 None => "no `markdown` field".to_string(),
215 Some(v) => format!("{v}"),
216 },
217 ))),
218 }
219 }
220
221 pub async fn get_document_url(&self, document_id: &str) -> Result<String> {
223 let url = format!(
224 "{}/documents/{}/content",
225 self.inner.base_url,
226 urlencoding::encode(document_id)
227 );
228 let resp = self
229 .c()
230 .send(self.c().inner.http.get(&url).header("Accept", "*/*"))
231 .await?;
232 Ok(resp
233 .headers()
234 .get("location")
235 .and_then(|v| v.to_str().ok())
236 .map(|s| s.to_string())
237 .unwrap_or_else(|| resp.url().to_string()))
238 }
239
240 pub async fn reconvert_document(&self, document_id: &str) -> Result<serde_json::Value> {
242 let url = format!(
243 "{}/documents/{}/convert",
244 self.inner.base_url,
245 urlencoding::encode(document_id)
246 );
247 self.c().post(&url, &serde_json::json!({})).await
248 }
249
250 pub async fn run_from_node(
261 &self,
262 project_id: i64,
263 script_name: &str,
264 seed_env: HashMap<String, serde_json::Value>,
265 skip_node_ids: Vec<usize>,
266 channel: Option<&str>,
267 inputs: Option<HashMap<String, serde_json::Value>>,
268 ) -> Result<RunResult> {
269 let scoped = ScopedExecutionsClient::new(Arc::clone(&self.inner), project_id);
270 scoped
271 .run_from(script_name, seed_env, skip_node_ids, channel, inputs, None)
272 .await
273 }
274
275 pub async fn await_execution(
277 &self,
278 execution_id: &str,
279 timeout_ms: Option<u64>,
280 poll_interval_ms: Option<u64>,
281 ) -> Result<ExecutionOutput> {
282 let interval = Duration::from_millis(poll_interval_ms.unwrap_or(500));
283 let deadline = timeout_ms.map(|ms| std::time::Instant::now() + Duration::from_millis(ms));
284
285 loop {
286 if let Some(deadline) = deadline {
287 if std::time::Instant::now() >= deadline {
288 return Err(AkribesError::Timeout {
289 execution_id: Some(execution_id.to_string()),
290 });
291 }
292 }
293
294 let output = self.get_output(execution_id).await?;
295 if let Some(output) = output {
296 match output.status.as_str() {
297 "completed" => return Ok(output),
298 "failed" | "cancelled" => {
299 let msg = output.error.clone().unwrap_or_default();
300 let eid = Some(execution_id.to_string());
301 return Err(match output.error_kind.as_deref() {
302 Some("RateLimit")
306 | Some("ServerError")
307 | Some("ServerError500")
308 | Some("BadGateway502")
309 | Some("ServiceUnavailable503")
310 | Some("GatewayTimeout504")
311 | Some("NetworkError") => {
312 let status = match output.error_kind.as_deref() {
313 Some("ServerError500") => Some(500u16),
314 Some("BadGateway502") => Some(502u16),
315 Some("ServiceUnavailable503") => Some(503u16),
316 Some("GatewayTimeout504") => Some(504u16),
317 Some("RateLimit") => Some(429u16),
318 _ => None,
319 };
320 AkribesError::Transient {
321 message: msg,
322 execution_id: eid,
323 retry_after: None,
324 status,
325 }
326 }
327 Some("AuthError") | Some("TokenLimit") => AkribesError::Fatal {
328 message: msg,
329 execution_id: eid,
330 },
331 _ => AkribesError::Script {
332 message: msg,
333 execution_id: eid,
334 },
335 });
336 }
337 _ => {}
338 }
339 }
340
341 tokio::time::sleep(interval).await;
342 }
343 }
344
345 pub async fn await_result(
350 &self,
351 execution_id: &str,
352 timeout_ms: Option<u64>,
353 poll_interval_ms: Option<u64>,
354 ) -> Result<ExecutionOutput> {
355 self.await_execution(execution_id, timeout_ms, poll_interval_ms)
356 .await
357 }
358}
359
/// An [`ExecutionsClient`] bound to a single project.
///
/// Adds the endpoints that live under `/projects/{project_id}/...`. The
/// unscoped methods stay reachable because this type dereferences to
/// [`ExecutionsClient`].
#[derive(Clone, Debug)]
pub struct ScopedExecutionsClient {
    /// The unscoped client that holds the shared state.
    base: ExecutionsClient,
    /// Id of the project inserted into every project-scoped URL.
    project_id: i64,
}
372
373impl Deref for ScopedExecutionsClient {
374 type Target = ExecutionsClient;
375 fn deref(&self) -> &Self::Target {
376 &self.base
377 }
378}
379
380impl ScopedExecutionsClient {
381 pub(crate) fn new(inner: Arc<Inner>, project_id: i64) -> Self {
382 Self {
383 base: ExecutionsClient { inner },
384 project_id,
385 }
386 }
387
388 fn c(&self) -> AkribesClient {
389 AkribesClient {
390 inner: Arc::clone(&self.base.inner),
391 }
392 }
393
394 fn project_url(&self) -> String {
395 format!("{}/projects/{}", self.base.inner.base_url, self.project_id)
396 }
397
398 fn script_url(&self, name: &str) -> String {
399 format!(
400 "{}/scripts/{}",
401 self.project_url(),
402 urlencoding::encode(name)
403 )
404 }
405
406 pub fn run(&self, script_name: &str) -> RunBuilder {
408 RunBuilder {
409 client: self.c(),
410 inner: Arc::clone(&self.base.inner),
411 project_id: self.project_id,
412 script_name: script_name.to_string(),
413 channel: "production".to_string(),
414 inputs: None,
415 triggered_by: None,
416 breakpoint_lines: None,
417 }
418 }
419
420 pub async fn run_stream(&self, req: RunBuilder) -> Result<crate::sub::run_stream::RunStream> {
430 crate::sub::run_stream::start_run_stream(Arc::clone(&self.base.inner), self.project_id, req)
431 .await
432 }
433
434 pub fn list(&self, script_name: &str) -> ListExecutionsBuilder {
436 ListExecutionsBuilder {
437 client: self.c(),
438 script_url: self.script_url(script_name),
439 status: None,
440 channel: None,
441 limit: None,
442 offset: None,
443 }
444 }
445
446 pub async fn cancel_run(&self, script_name: &str) -> Result<bool> {
448 let url = format!("{}/run", self.script_url(script_name));
449 self.c().delete(&url).await
450 }
451
452 pub async fn cancel_all(&self, script_name: &str) -> Result<bool> {
457 self.cancel_run(script_name).await
458 }
459
460 pub async fn run_from(
462 &self,
463 script_name: &str,
464 seed_env: HashMap<String, serde_json::Value>,
465 skip_node_ids: Vec<usize>,
466 channel: Option<&str>,
467 inputs: Option<HashMap<String, serde_json::Value>>,
468 triggered_by: Option<&str>,
469 ) -> Result<RunResult> {
470 let channel = channel.unwrap_or("draft");
471 let url = format!(
472 "{}/run/from?channel={}",
473 self.script_url(script_name),
474 urlencoding::encode(channel)
475 );
476 let tb = triggered_by
477 .map(|s| s.to_string())
478 .unwrap_or_else(|| self.base.inner.name.clone());
479 self.c()
480 .post(
481 &url,
482 &RunFromRequest {
483 inputs,
484 seed_env,
485 skip_node_ids,
486 triggered_by: Some(tb),
487 },
488 )
489 .await
490 }
491
492 pub async fn get_graph(
494 &self,
495 script_name: &str,
496 version_id: Option<i64>,
497 ) -> Result<GraphResponse> {
498 #[derive(serde::Serialize)]
499 struct Q {
500 #[serde(skip_serializing_if = "Option::is_none")]
501 version: Option<i64>,
502 }
503 let base = format!("{}/graph", self.script_url(script_name));
504 let url = AkribesClient::url_with_query(
505 &base,
506 &Q {
507 version: version_id,
508 },
509 );
510 let res = self.c().send(self.c().inner.http.get(&url)).await?;
511 crate::client::decode_json(res).await
512 }
513
514 pub async fn get_project_cost(
516 &self,
517 since: Option<&str>,
518 until: Option<&str>,
519 ) -> Result<crate::models::ProjectCost> {
520 #[derive(serde::Serialize)]
521 struct Q<'a> {
522 #[serde(skip_serializing_if = "Option::is_none")]
523 since: Option<&'a str>,
524 #[serde(skip_serializing_if = "Option::is_none")]
525 until: Option<&'a str>,
526 }
527 let base = format!("{}/cost", self.project_url());
528 let url = AkribesClient::url_with_query(&base, &Q { since, until });
529 let res = self.c().send(self.c().inner.http.get(&url)).await?;
530 crate::client::decode_json(res).await
531 }
532
533 pub async fn get_cost(&self, script_name: &str) -> Result<CostAggregation> {
535 let url = format!("{}/cost", self.script_url(script_name));
536 self.c().get_opt::<CostAggregation>(&url).await.map(|o| {
537 o.unwrap_or_else(|| CostAggregation {
538 total_executions: 0,
539 total_cost_usd: 0.0,
540 avg_cost_usd: 0.0,
541 total_input_tokens: 0,
542 total_output_tokens: 0,
543 total_tool_tokens: 0,
544 by_version: vec![],
545 })
546 })
547 }
548
549 pub async fn run_with_upload(
551 &self,
552 script_name: &str,
553 files: HashMap<String, (String, Vec<u8>)>,
554 channel: Option<&str>,
555 triggered_by: Option<&str>,
556 ) -> Result<RunResult> {
557 let channel = channel.unwrap_or("production");
558 let url = format!(
559 "{}/run/upload?channel={}",
560 self.script_url(script_name),
561 urlencoding::encode(channel)
562 );
563
564 let mut form = reqwest::multipart::Form::new();
565 for (input_name, (filename, data)) in files {
566 let part = reqwest::multipart::Part::bytes(data)
567 .file_name(filename)
568 .mime_str("application/octet-stream")
569 .expect("valid static MIME type");
570 form = form.part(input_name, part);
571 }
572
573 let tb = triggered_by
574 .map(|s| s.to_string())
575 .unwrap_or_else(|| self.base.inner.name.clone());
576 let meta = serde_json::json!({ "triggered_by": tb });
577 form = form.text("_meta", meta.to_string());
578
579 self.c().post_multipart(&url, form).await
580 }
581
582 pub async fn run_with_s3(
584 &self,
585 script_name: &str,
586 inputs: HashMap<String, S3DocumentRef>,
587 channel: Option<&str>,
588 triggered_by: Option<&str>,
589 ) -> Result<RunResult> {
590 let url = format!("{}/run/s3", self.script_url(script_name));
591 let tb = triggered_by
592 .map(|s| s.to_string())
593 .unwrap_or_else(|| self.base.inner.name.clone());
594 self.c()
595 .post(
596 &url,
597 &RunWithS3Request {
598 inputs,
599 channel: channel.map(|s| s.to_string()),
600 triggered_by: Some(tb),
601 },
602 )
603 .await
604 }
605}
606
/// Builder for a single script run, created by `ScopedExecutionsClient::run`.
///
/// Nothing is sent until `execute` (or `execute_and_await`) is called.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until .execute() is called"]
pub struct RunBuilder {
    /// Handle used to send the request.
    client: AkribesClient,
    /// Shared client state; supplies `base_url`, the default `triggered_by`
    /// name and the contract data checked before sending.
    inner: Arc<Inner>,
    /// Project that owns the script.
    project_id: i64,
    /// Name of the script to run; percent-encoded when the URL is built.
    script_name: String,
    /// Channel sent as the `channel` query parameter
    /// (`ScopedExecutionsClient::run` initialises it to `"production"`).
    channel: String,
    /// Inputs collected so far, keyed by input name; `None` until the first
    /// input is added.
    inputs: Option<HashMap<String, serde_json::Value>>,
    /// Overrides the `triggered_by` value sent with the run; falls back to
    /// `inner.name` when `None`.
    triggered_by: Option<String>,
    /// Forwarded unchanged in the run request when set.
    breakpoint_lines: Option<Vec<usize>>,
}
642
643impl RunBuilder {
644 pub fn script_name(&self) -> &str {
646 &self.script_name
647 }
648
649 fn script_url(&self) -> String {
650 format!(
651 "{}/projects/{}/scripts/{}",
652 self.inner.base_url,
653 self.project_id,
654 urlencoding::encode(&self.script_name)
655 )
656 }
657
658 pub fn channel(mut self, channel: impl Into<String>) -> Self {
659 self.channel = channel.into();
660 self
661 }
662
663 pub fn inputs(mut self, inputs: HashMap<String, serde_json::Value>) -> Self {
666 match &mut self.inputs {
667 Some(existing) => existing.extend(inputs),
668 None => self.inputs = Some(inputs),
669 }
670 self
671 }
672
673 pub fn input<V: Into<serde_json::Value>>(mut self, name: impl Into<String>, value: V) -> Self {
675 self.inputs
676 .get_or_insert_with(HashMap::new)
677 .insert(name.into(), value.into());
678 self
679 }
680
681 pub fn document(self, name: impl Into<String>, doc_id: impl Into<String>) -> Self {
686 self.input(name, serde_json::Value::String(doc_id.into()))
687 }
688
689 pub fn documents<I, S>(self, name: impl Into<String>, doc_ids: I) -> Self
692 where
693 I: IntoIterator<Item = S>,
694 S: Into<String>,
695 {
696 let arr: Vec<serde_json::Value> = doc_ids
697 .into_iter()
698 .map(|d| serde_json::Value::String(d.into()))
699 .collect();
700 self.input(name, serde_json::Value::Array(arr))
701 }
702
703 pub fn triggered_by(mut self, triggered_by: impl Into<String>) -> Self {
704 self.triggered_by = Some(triggered_by.into());
705 self
706 }
707
708 pub fn breakpoint_lines(mut self, lines: Vec<usize>) -> Self {
709 self.breakpoint_lines = Some(lines);
710 self
711 }
712
713 pub async fn execute(self) -> Result<RunResult> {
714 let input_keys: Vec<&str> = self
715 .inputs
716 .as_ref()
717 .map(|d| d.keys().map(|k| k.as_str()).collect())
718 .unwrap_or_default();
719 validate_contract(
720 &self.inner,
721 &self.script_name,
722 if input_keys.is_empty() {
723 None
724 } else {
725 Some(&input_keys)
726 },
727 )?;
728 let url = format!(
729 "{}/run?channel={}",
730 self.script_url(),
731 urlencoding::encode(&self.channel)
732 );
733 let triggered_by = self.triggered_by.unwrap_or_else(|| self.inner.name.clone());
734 self.client
735 .post(
736 &url,
737 &RunRequest {
738 inputs: self.inputs,
739 triggered_by: Some(triggered_by),
740 breakpoint_lines: self.breakpoint_lines,
741 },
742 )
743 .await
744 }
745
746 pub async fn execute_and_await(
747 self,
748 timeout_ms: Option<u64>,
749 ) -> Result<(String, ExecutionOutput)> {
750 let execs = ExecutionsClient {
751 inner: Arc::clone(&self.inner),
752 };
753 let run = self.execute().await?;
754 let eid = run.execution_id.clone();
755 let output = execs.await_execution(&eid, timeout_ms, None).await?;
756 Ok((eid, output))
757 }
758}
759
/// Builder for listing the executions of one script, created by
/// `ScopedExecutionsClient::list`.
///
/// Nothing is sent until `fetch` is called.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until .fetch() is called"]
pub struct ListExecutionsBuilder {
    /// Handle used to send the request.
    client: AkribesClient,
    /// URL of the script; `/executions` is appended when fetching.
    script_url: String,
    /// Sent as the `status` query parameter when set.
    status: Option<String>,
    /// Sent as the `channel` query parameter when set.
    channel: Option<String>,
    /// Sent as the `limit` query parameter when set.
    limit: Option<i64>,
    /// Sent as the `offset` query parameter when set.
    offset: Option<i64>,
}
772
773impl ListExecutionsBuilder {
774 pub fn status(mut self, status: impl Into<String>) -> Self {
775 self.status = Some(status.into());
776 self
777 }
778
779 pub fn channel(mut self, channel: impl Into<String>) -> Self {
780 self.channel = Some(channel.into());
781 self
782 }
783
784 pub fn limit(mut self, limit: i64) -> Self {
785 self.limit = Some(limit);
786 self
787 }
788
789 pub fn offset(mut self, offset: i64) -> Self {
790 self.offset = Some(offset);
791 self
792 }
793
794 pub async fn fetch(self) -> Result<Vec<ExecutionStatus>> {
795 #[derive(serde::Serialize)]
796 struct Q<'a> {
797 #[serde(skip_serializing_if = "Option::is_none")]
798 status: Option<&'a str>,
799 #[serde(skip_serializing_if = "Option::is_none")]
800 channel: Option<&'a str>,
801 #[serde(skip_serializing_if = "Option::is_none")]
802 limit: Option<i64>,
803 #[serde(skip_serializing_if = "Option::is_none")]
804 offset: Option<i64>,
805 }
806 let base = format!("{}/executions", self.script_url);
807 let url = AkribesClient::url_with_query(
808 &base,
809 &Q {
810 status: self.status.as_deref(),
811 channel: self.channel.as_deref(),
812 limit: self.limit,
813 offset: self.offset,
814 },
815 );
816 let res = self.client.send(self.client.inner.http.get(&url)).await?;
817 if res.status() == reqwest::StatusCode::NOT_FOUND {
818 return Ok(vec![]);
819 }
820 crate::client::decode_json(res).await
821 }
822}