1use crate::{util, Client, Error};
2use serde::{
3 de::{Error as DeError, Visitor},
4 Deserialize, Deserializer, Serialize, Serializer,
5};
6use std::{
7 collections::HashMap,
8 num::NonZeroU64,
9 ops::{Deref, DerefMut},
10 sync::Arc,
11};
12use tokio::stream::StreamExt;
13
14impl Client {
15 pub fn new_batch(&self) -> crate::Result<BatchBuilder> {
16 if self.data.billing_project.is_none() {
17 Err(Error::Msg(
18 "cannot build batch without a billing project".into(),
19 ))
20 } else {
21 Ok(BatchBuilder {
22 bc: self.clone(),
23 attributes: HashMap::new(),
24 callback: None,
25 jobs: Vec::new(),
26 })
27 }
28 }
29}
30
31#[derive(Debug)]
32pub struct BatchBuilder {
33 bc: Client,
34 attributes: HashMap<String, String>,
35 callback: Option<reqwest::Url>,
36 jobs: Vec<JobSpec>,
37}
38
39#[derive(Debug)]
40pub struct Batch {
41 bc: Client,
42 id: u64,
43 attributes: HashMap<String, String>,
44 jobs: Vec<JobSpec>,
45}
46
47#[derive(Debug, Serialize)]
48struct BatchSpec<'a> {
49 attributes: &'a HashMap<String, String>,
50 billing_project: &'a str,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 callback: Option<&'a reqwest::Url>,
53 n_jobs: usize,
54 token: String,
55}
56
57impl BatchBuilder {
58 pub fn name(&mut self, name: impl Into<String>) -> &mut Self {
59 self.attribute("name", name.into())
60 }
61
62 pub fn attribute(&mut self, key: impl ToString, value: impl ToString) -> &mut Self {
63 self.attributes.insert(key.to_string(), value.to_string());
64 self
65 }
66
67 pub fn attributes<I, S, T>(&mut self, attrs: I) -> &mut Self
68 where
69 S: ToString,
70 T: ToString,
71 I: IntoIterator<Item = (S, T)>,
72 {
73 let attrs = attrs
74 .into_iter()
75 .map(|(k, v)| (k.to_string(), v.to_string()));
76 self.attributes.extend(attrs);
77 self
78 }
79
80 pub fn callback(&mut self, url: impl reqwest::IntoUrl) -> crate::Result<&mut Self> {
81 self.callback.replace(url.into_url()?);
82 Ok(self)
83 }
84
85 pub fn new_job(&mut self, image: impl Into<String>, cmd: impl Into<String>) -> JobBuilder<'_> {
87 JobBuilder::new(self, image.into(), cmd.into())
88 }
89
90 pub fn add_job(&mut self, mut spec: JobSpec) -> crate::Result<JobBuilder<'_>> {
96 spec.id = self.jobs.len() + 1;
97 for &id in &spec.parent_ids {
98 if id >= spec.id {
99 return Err(Error::Msg(
100 format!(
101 "invalid parent id {}, parent ids must be less than job ids (which was {}",
102 id, spec.id
103 )
104 .into(),
105 ));
106 }
107 }
108 Ok(JobBuilder { bb: self, spec })
109 }
110
111 async fn submit_jobs(&self, id: u64, specses: Vec<Vec<Vec<u8>>>) -> crate::Result<()> {
112 use tokio::sync::mpsc::{self, error::TryRecvError};
113 if specses.is_empty() {
114 return Ok(());
115 }
116
117 let endpoint = Arc::new(format!("/api/v1alpha/batches/{}/jobs/create", id));
118
119 let (tx, mut rx) = mpsc::channel(10);
120 let n_reqs = specses.len();
121 let mut reqs = 0;
122 for specs in specses {
123 match rx.try_recv() {
124 Ok(Ok(_)) => reqs += 1,
125 Ok(Err(e)) => return Err(e),
126 Err(TryRecvError::Closed) => {
127 panic!("submit_jobs: all senders have been dropped, this is a bug")
128 }
129 Err(TryRecvError::Empty) => {} }
131
132 if let Ok(Err(e)) = rx.try_recv() {
133 return Err(e);
134 }
135
136 let mut bytes = Vec::new();
137 bytes.push(b'[');
138 for spec in specs {
139 bytes.extend_from_slice(&spec);
140 bytes.push(b',');
141 }
142 if let Some(b) = bytes.last_mut() {
143 *b = b']';
144 }
145
146 debug_assert!(
147 serde_json::from_slice::<serde_json::Value>(&bytes).is_ok(),
148 "bytes are not valid json"
149 );
150
151 let bc = self.bc.clone();
152 let ep = endpoint.clone();
153 let mut tx = tx.clone();
154 tokio::spawn(async move {
155 let resp = bc.post(&*ep, "application/json", bytes).await;
156 let _ = tx.send(resp).await;
157 });
158 }
159
160 std::mem::drop(tx);
161 while let Some(resp) = rx.next().await {
162 if let Err(e) = resp {
163 return Err(e);
164 }
165 reqs += 1;
166 }
167 assert_eq!(reqs, n_reqs, "did not recieve enough responses");
168
169 Ok(())
170 }
171
172 pub async fn submit(self) -> crate::Result<Batch> {
173 #[derive(Deserialize)]
174 struct BatchID {
175 id: u64,
176 }
177
178 let spec = BatchSpec {
179 attributes: &self.attributes,
180 billing_project: self.bc.data.billing_project.as_deref().unwrap(),
181 callback: self.callback.as_ref(),
182 n_jobs: self.jobs.len(),
183 token: util::gen_token(),
184 };
185 let BatchID { id } = self
186 .bc
187 .post_json("/api/v1alpha/batches/create", &spec)
188 .await?
189 .json()
190 .await?;
191
192 const MAX_BUNCH_SIZE: usize = 1024 * 1024;
193 const MAX_BUNCH_JOBS: usize = 1024;
194
195 let serialized_jobs = self.jobs.iter().map(|spec| {
196 (
197 spec.id,
198 serde_json::to_vec(spec).expect("to_vec should not fail"),
199 )
200 });
201
202 let mut bunches = Vec::new();
203 let mut bunch = Vec::new();
204 let mut bunch_bytes = 0;
205 for (id, spec) in serialized_jobs {
206 if spec.len() > MAX_BUNCH_SIZE {
207 return Err(Error::Msg(
208 format!(
209 "job {} too large, was {}B which is greater than the limit of {}B",
210 id,
211 spec.len(),
212 MAX_BUNCH_SIZE
213 )
214 .into(),
215 ));
216 } else if bunch_bytes + spec.len() < MAX_BUNCH_SIZE && bunch.len() < MAX_BUNCH_JOBS {
217 bunch_bytes += spec.len() + 1;
218 bunch.push(spec);
219 } else {
220 bunches.push(bunch);
221 bunch_bytes = spec.len() + 1;
222 bunch = vec![spec];
223 }
224 }
225
226 if !bunch.is_empty() {
227 bunches.push(bunch);
228 }
229
230 self.submit_jobs(id, bunches).await?;
231
232 let path = format!("/api/v1alpha/batches/{}/close", id);
233 self.bc.patch(&path).await?;
234 Ok(Batch {
235 bc: self.bc,
236 id,
237 attributes: self.attributes,
238 jobs: self.jobs,
239 })
240 }
241}
242
243impl Batch {
244 pub fn id(&self) -> u64 {
245 self.id
246 }
247
248 pub fn web_url(&self) -> reqwest::Url {
249 self.bc.join_url(&format!("/batches/{}", self.id))
250 }
251
252 pub async fn cancel(&self) -> crate::Result<()> {
253 let ep = format!("/api/v1alpha/batches/{}/cancel", self.id);
254 self.bc.patch(&ep).await.map(|_| ())
255 }
256}
257
258#[derive(Debug, Serialize, Deserialize)]
259pub struct JobSpec {
260 #[serde(rename = "job_id", default)]
261 id: usize,
262 #[serde(skip_serializing_if = "<&bool as std::ops::Not>::not", default)]
263 always_run: bool,
264 #[serde(skip_serializing_if = "HashMap::is_empty", default)]
265 attributes: HashMap<String, String>,
266 command: Vec<String>,
267 #[serde(skip_serializing_if = "HashMap::is_empty", with = "env_map", default)]
268 env: HashMap<String, String>,
269 #[serde(skip_serializing_if = "Vec::is_empty", default)]
270 gcsfuse: Vec<GcsFuseMount>,
271 image: String,
272 #[serde(skip_serializing_if = "Vec::is_empty", default)]
273 input_files: Vec<FileMapping>,
274 #[serde(default)]
275 mount_docker_socket: bool,
276 #[serde(skip_serializing_if = "Vec::is_empty", default)]
277 output_files: Vec<FileMapping>,
278 #[serde(default)]
279 parent_ids: Vec<usize>,
280 #[serde(skip_serializing_if = "Option::is_none", default)]
281 port: Option<u16>,
282 #[serde(skip_serializing_if = "Option::is_none", default)]
283 requester_pays_project: Option<String>,
284 #[serde(skip_serializing_if = "Option::is_none", default)]
285 network: Option<String>,
286 #[serde(default)]
287 resources: Resources,
288 #[serde(skip_serializing_if = "Vec::is_empty", default)]
289 secrets: Vec<Secret>,
290 #[serde(skip_serializing_if = "Option::is_none", default)]
291 service_account: Option<ServiceAccount>,
292 #[serde(skip_serializing_if = "Option::is_none", default)]
293 timeout: Option<NonZeroU64>,
295}
296
297impl JobSpec {
298 pub fn id(&self) -> usize {
299 self.id
300 }
301}
302
303#[derive(Debug)]
304pub struct JobBuilder<'bb> {
305 bb: &'bb mut BatchBuilder,
306 spec: JobSpec,
307}
308
309impl Deref for JobBuilder<'_> {
310 type Target = JobSpec;
311 fn deref(&self) -> &Self::Target {
312 &self.spec
313 }
314}
315
316impl DerefMut for JobBuilder<'_> {
317 fn deref_mut(&mut self) -> &mut Self::Target {
318 &mut self.spec
319 }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct GcsFuseMount {
324 pub bucket: String,
325 pub mount_path: String,
326 pub read_only: bool,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct FileMapping {
331 pub from: String,
332 pub to: String,
333}
334
335impl FileMapping {
336 fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
337 Self {
338 from: from.into(),
339 to: to.into(),
340 }
341 }
342}
343
344#[derive(Debug, Clone, Serialize, Deserialize)]
345pub struct ServiceAccount {
346 pub namespace: String,
347 pub name: String,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize)]
351pub struct Secret {
352 pub namespace: String,
353 pub name: String,
354 pub mount_path: String,
355}
356
357#[derive(Debug, Clone, Copy, Deserialize)]
358pub struct Resources {
359 #[serde(deserialize_with = "deserialize_cpu")]
360 pub cpu: f64,
361 #[serde(deserialize_with = "deserialize_mem")]
362 pub memory: u64,
363 #[serde(deserialize_with = "deserialize_mem")]
364 pub storage: u64,
365}
366
367static MEM_RE_STR: &str = r"[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?";
368static CPU_RE_STR: &str = r"[+]?((?:[0-9]*[.])?[0-9]+)([m])?";
369lazy_static::lazy_static! {
370 static ref MEM_RE: regex::Regex = regex::Regex::new(MEM_RE_STR).unwrap();
371 static ref CPU_RE: regex::Regex = regex::Regex::new(CPU_RE_STR).unwrap();
372}
373
374struct ReVisitor(&'static regex::Regex);
375
376impl<'de> Visitor<'de> for ReVisitor {
377 type Value = (f64, Option<&'de str>);
378
379 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
380 write!(f, "expected value to match regex, {:?}", self.0.as_str())
381 }
382
383 fn visit_borrowed_str<E: DeError>(self, val: &'de str) -> Result<Self::Value, E> {
384 use serde::de;
385 if let Some(groups) = self.0.captures(val) {
386 let v = groups.get(1).unwrap().as_str().parse().unwrap();
389 Ok((v, groups.get(2).map(|m| m.as_str())))
390 } else {
391 Err(E::invalid_value(de::Unexpected::Str(val), &self))
392 }
393 }
394}
395
396fn deserialize_mem<'de, D>(de: D) -> Result<u64, D::Error>
397where
398 D: Deserializer<'de>,
399{
400 let (val, suf) = de.deserialize_str(ReVisitor(&*MEM_RE))?;
401 let mul = match suf.unwrap_or("") {
402 "K" => 1000f64,
403 "Ki" => 1024f64,
404 "M" => 1000f64.powi(2),
405 "Mi" => 1024f64.powi(2),
406 "G" => 1000f64.powi(3),
407 "Gi" => 1024f64.powi(3),
408 "T" => 1000f64.powi(4),
409 "Ti" => 1024f64.powi(4),
410 "P" => 1000f64.powi(5),
411 "Pi" => 1024f64.powi(5),
412 "" => 1.,
413 _ => unreachable!(),
414 };
415
416 Ok((val * mul).ceil() as u64)
417}
418
419fn deserialize_cpu<'de, D>(de: D) -> Result<f64, D::Error>
420where
421 D: Deserializer<'de>,
422{
423 let (val, suf) = de.deserialize_str(ReVisitor(&*CPU_RE))?;
424 if Some("m") == suf {
425 Ok(val / 1000.)
426 } else {
427 Ok(val)
428 }
429}
430
431impl Serialize for Resources {
433 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
434 where
435 S: serde::ser::Serializer,
436 {
437 use serde::ser::SerializeStruct;
438 let mut s = serializer.serialize_struct("Resources", 3)?;
439 s.serialize_field("cpu", &self.cpu.to_string())?;
440 s.serialize_field("memory", &self.memory.to_string())?;
441 s.serialize_field("storage", &self.storage.to_string())?;
442 s.end()
443 }
444}
445
446impl Resources {
447 const DEFAULT: Self = Resources {
448 cpu: 1.,
449 memory: (375 * 1024 * 1024 * 1024) / 100, storage: 10 * 1024 * 1024 * 1024, };
452}
453
454impl Default for Resources {
455 fn default() -> Self {
456 Self::DEFAULT
457 }
458}
459
460impl<'bb> JobBuilder<'bb> {
461 fn new(bb: &'bb mut BatchBuilder, image: String, cmd: String) -> Self {
462 let id = bb.jobs.len() + 1;
463 Self {
464 bb,
465 spec: JobSpec {
466 id,
467 image,
468 command: vec![cmd],
469 always_run: false,
470 attributes: HashMap::new(),
471 env: HashMap::new(),
472 gcsfuse: Vec::new(),
473 input_files: Vec::new(),
474 mount_docker_socket: false,
475 output_files: Vec::new(),
476 parent_ids: Vec::new(),
477 port: None,
478 requester_pays_project: None,
479 network: None,
480 resources: Resources::DEFAULT,
481 secrets: Vec::new(),
482 service_account: None,
483 timeout: None,
484 },
485 }
486 }
487 pub fn name(&mut self, name: impl Into<String>) -> &mut Self {
488 self.attribute("name", name.into())
489 }
490
491 pub fn attribute(&mut self, key: impl ToString, value: impl ToString) -> &mut Self {
492 self.attributes.insert(key.to_string(), value.to_string());
493 self
494 }
495
496 pub fn attributes<I, S, T>(&mut self, attrs: I) -> &mut Self
497 where
498 S: ToString,
499 T: ToString,
500 I: IntoIterator<Item = (S, T)>,
501 {
502 let attrs = attrs
503 .into_iter()
504 .map(|(k, v)| (k.to_string(), v.to_string()));
505 self.attributes.extend(attrs);
506 self
507 }
508
509 pub fn always_run(&mut self, always_run: bool) -> &mut Self {
511 self.always_run = always_run;
512 self
513 }
514
515 pub fn arg(&mut self, arg: impl Into<String>) -> &mut Self {
517 self.command.push(arg.into());
518 self
519 }
520
521 pub fn args<I, S>(&mut self, args: I) -> &mut Self
523 where
524 S: Into<String>,
525 I: IntoIterator<Item = S>,
526 {
527 self.command.extend(args.into_iter().map(S::into));
528 self
529 }
530
531 pub fn args_mut(&mut self) -> &mut Vec<String> {
533 &mut self.command
534 }
535
536 pub fn env(&mut self, key: impl Into<String>, val: impl Into<String>) -> &mut Self {
538 self.env.insert(key.into(), val.into());
539 self
540 }
541
542 pub fn env_vars<I, S, T>(&mut self, vars: I) -> &mut Self
544 where
545 S: Into<String>,
546 T: Into<String>,
547 I: IntoIterator<Item = (S, T)>,
548 {
549 self.env
550 .extend(vars.into_iter().map(|(k, v)| (k.into(), v.into())));
551 self
552 }
553
554 pub fn env_remove(&mut self, key: &impl std::borrow::Borrow<str>) -> Option<String> {
556 self.env.remove(key.borrow())
557 }
558
559 pub fn env_clear(&mut self) {
562 self.env.clear();
563 }
564
565 pub fn gcsfuse(
567 &mut self,
568 bucket: impl Into<String>,
569 mount_path: impl Into<String>,
570 read_only: bool,
571 ) -> &mut Self {
572 let gcsfuse_mount = GcsFuseMount {
573 bucket: bucket.into(),
574 mount_path: mount_path.into(),
575 read_only,
576 };
577 self.gcsfuse.push(gcsfuse_mount);
578 self
579 }
580
581 pub fn input_file(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
583 self.input_files.push(FileMapping::new(from, to));
584 self
585 }
586
587 pub fn input_files<I, S, T>(&mut self, paths: I) -> &mut Self
589 where
590 S: Into<String>,
591 T: Into<String>,
592 I: IntoIterator<Item = (S, T)>,
593 {
594 self.input_files
595 .extend(paths.into_iter().map(|(f, t)| FileMapping::new(f, t)));
596 self
597 }
598
599 pub fn mount_docker_socket(&mut self, mount: bool) -> &mut Self {
601 self.mount_docker_socket = mount;
602 self
603 }
604
605 pub fn output_file(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
607 self.output_files.push(FileMapping::new(from, to));
608 self
609 }
610
611 pub fn output_files<I, S, T>(&mut self, paths: I) -> &mut Self
613 where
614 S: Into<String>,
615 T: Into<String>,
616 I: IntoIterator<Item = (S, T)>,
617 {
618 self.output_files
619 .extend(paths.into_iter().map(|(f, t)| FileMapping::new(f, t)));
620 self
621 }
622
623 pub fn parent(&mut self, parent_id: usize) -> &mut Self {
630 assert!(parent_id < self.id, "invalid parent_id: {}", parent_id);
631 self.parent_ids.push(parent_id);
632 self
633 }
634
635 pub fn parents(&mut self, parent_ids: impl IntoIterator<Item = usize>) -> &mut Self {
642 let start = self.parent_ids.len();
643 self.parent_ids.extend(parent_ids);
644 if self.parent_ids[start..].iter().any(|&id| id >= self.id) {
645 let invalids = self.parent_ids[start..]
646 .iter()
647 .filter(|&&id| id >= self.id)
648 .collect::<Vec<_>>();
649 panic!("invalid parent ids: {:?}", invalids);
650 }
651 self
652 }
653
654 pub fn port(&mut self, port: u16) -> &mut Self {
656 self.port.replace(port);
657 self
658 }
659
660 pub fn clear_port(&mut self) -> &mut Self {
662 self.port.take();
663 self
664 }
665
666 pub fn requester_pays_project(&mut self, project: impl Into<String>) -> &mut Self {
668 self.requester_pays_project.replace(project.into());
669 self
670 }
671
672 pub fn network(&mut self, network: impl Into<String>) -> &mut Self {
674 self.network.replace(network.into());
675 self
676 }
677
678 pub fn resources(&mut self, cpu: f64, memory: u64, storage: u64) -> &mut Self {
682 self.resources = Resources {
683 cpu,
684 memory,
685 storage,
686 };
687 self
688 }
689
690 pub fn cpu(&mut self, cpu: f64) -> &mut Self {
692 self.resources.cpu = cpu;
693 self
694 }
695
696 pub fn memory(&mut self, memory: u64) -> &mut Self {
698 self.resources.memory = memory;
699 self
700 }
701
702 pub fn storage(&mut self, storage: u64) -> &mut Self {
704 self.resources.storage = storage;
705 self
706 }
707
708 pub fn service_account(
710 &mut self,
711 namespace: impl Into<String>,
712 name: impl Into<String>,
713 ) -> &mut Self {
714 self.service_account.replace(ServiceAccount {
715 namespace: namespace.into(),
716 name: name.into(),
717 });
718 self
719 }
720
721 pub fn timeout(&mut self, timeout: u64) -> &mut Self {
723 self.timeout = NonZeroU64::new(timeout);
724 self
725 }
726
727 pub fn build(self) -> usize {
729 let JobBuilder { spec, bb } = self;
730 let id = spec.id;
731 bb.jobs.push(spec);
732 assert_eq!(id, bb.jobs.len(), "mismatch in job count and job id");
733 id
734 }
735}
736
737mod env_map {
738 use super::*;
739
740 #[derive(Serialize, Deserialize)]
741 struct EnvMapping<'a> {
742 name: &'a str,
743 value: &'a str,
744 }
745
746 impl<'a> From<(&'a str, &'a str)> for EnvMapping<'a> {
747 fn from((name, value): (&'a str, &'a str)) -> Self {
748 Self { name, value }
749 }
750 }
751
752 impl<'a, S1: AsRef<str>, S2: AsRef<str>> From<(&'a S1, &'a S2)> for EnvMapping<'a> {
753 fn from((name, value): (&'a S1, &'a S2)) -> Self {
754 Self {
755 name: name.as_ref(),
756 value: value.as_ref(),
757 }
758 }
759 }
760
761 pub fn deserialize<'de, D>(de: D) -> Result<HashMap<String, String>, D::Error>
762 where
763 D: Deserializer<'de>,
764 {
765 use serde::de::SeqAccess;
766 use std::fmt;
767 struct EnvMapVisitor;
768
769 impl<'de> Visitor<'de> for EnvMapVisitor {
770 type Value = HashMap<String, String>;
771
772 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
773 write!(formatter, "a sequence of name/value pairs")
774 }
775
776 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
777 let mut map = seq
778 .size_hint()
779 .map_or_else(HashMap::new, HashMap::with_capacity);
780 while let Some(EnvMapping { name, value }) = seq.next_element()? {
781 map.insert(name.into(), value.into());
782 }
783 Ok(map)
784 }
785 }
786
787 de.deserialize_seq(EnvMapVisitor)
788 }
789
790 pub fn serialize<S>(env: &HashMap<String, String>, ser: S) -> Result<S::Ok, S::Error>
791 where
792 S: Serializer,
793 {
794 use serde::ser::SerializeSeq;
795 let len = env.len();
796 let mut seq = ser.serialize_seq(Some(len))?;
797 for map in env.iter().map(EnvMapping::from) {
798 seq.serialize_element(&map)?;
799 }
800 seq.end()
801 }
802}