1use cargo_options::Run;
2use clap::Args;
3use matchit::{InsertError, MatchError, Router};
4use serde::{
5 Deserialize, Serialize,
6 de::{Error, Visitor},
7 ser::SerializeSeq,
8};
9use serde_json::{Value, json};
10use std::{collections::HashMap, path::PathBuf};
11
12use crate::{
13 cargo::{count_common_options, serialize_common_options},
14 env::{EnvOptions, Environment},
15 error::MetadataError,
16 lambda::Timeout,
17};
18
19use cargo_lambda_remote::tls::TlsOptions;
20
21#[cfg(windows)]
22const DEFAULT_INVOKE_ADDRESS: &str = "127.0.0.1";
23
24#[cfg(not(windows))]
25const DEFAULT_INVOKE_ADDRESS: &str = "::";
26
27const DEFAULT_INVOKE_PORT: u16 = 9000;
28
29#[derive(Args, Clone, Debug, Default, Deserialize)]
30#[command(
31 name = "watch",
32 visible_alias = "start",
33 after_help = "Full command documentation: https://www.cargo-lambda.info/commands/watch.html"
34)]
35pub struct Watch {
36 #[arg(long, visible_alias = "no-reload")]
38 #[serde(default)]
39 pub ignore_changes: bool,
40
41 #[arg(long)]
44 #[serde(default)]
45 pub only_lambda_apis: bool,
46
47 #[arg(short = 'A', long, default_value = DEFAULT_INVOKE_ADDRESS)]
48 #[serde(default = "default_invoke_address")]
49 pub invoke_address: String,
51
52 #[arg(short = 'P', long, default_value_t = DEFAULT_INVOKE_PORT)]
54 #[serde(default = "default_invoke_port")]
55 pub invoke_port: u16,
56
57 #[arg(long)]
59 #[serde(default)]
60 pub print_traces: bool,
61
62 #[arg(long, short)]
64 #[serde(default)]
65 pub wait: bool,
66
67 #[arg(long)]
69 #[serde(default)]
70 pub disable_cors: bool,
71
72 #[arg(long)]
74 #[serde(default)]
75 pub timeout: Option<Timeout>,
76
77 #[command(flatten)]
78 #[serde(flatten)]
79 pub cargo_opts: Run,
80
81 #[command(flatten)]
82 #[serde(flatten)]
83 pub env_options: EnvOptions,
84
85 #[command(flatten)]
86 #[serde(flatten)]
87 pub tls_options: TlsOptions,
88
89 #[arg(skip)]
90 #[serde(default, skip_serializing_if = "is_empty_router")]
91 pub router: Option<FunctionRouter>,
92}
93
94impl Watch {
95 pub fn manifest_path(&self) -> PathBuf {
96 self.cargo_opts
97 .manifest_path
98 .clone()
99 .unwrap_or_else(|| "Cargo.toml".into())
100 }
101
102 pub fn pkg_name(&self) -> Option<String> {
105 if self.cargo_opts.packages.len() > 1 {
106 return None;
107 }
108 self.cargo_opts.packages.first().map(|s| s.to_string())
109 }
110
111 pub fn bin_name(&self) -> Option<String> {
112 if self.cargo_opts.bin.len() > 1 {
113 return None;
114 }
115 self.cargo_opts.bin.first().map(|s| s.to_string())
116 }
117
118 pub fn lambda_environment(
119 &self,
120 base: &HashMap<String, String>,
121 ) -> Result<Environment, MetadataError> {
122 self.env_options.lambda_environment(base)
123 }
124}
125
126impl Serialize for Watch {
127 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128 where
129 S: serde::Serializer,
130 {
131 use serde::ser::SerializeStruct;
132
133 let field_count = self.ignore_changes as usize
135 + self.only_lambda_apis as usize
136 + !self.invoke_address.is_empty() as usize
137 + (self.invoke_port != 0) as usize
138 + self.print_traces as usize
139 + self.wait as usize
140 + self.disable_cors as usize
141 + self.timeout.is_some() as usize
142 + self.router.is_some() as usize
143 + self.cargo_opts.manifest_path.is_some() as usize
144 + self.cargo_opts.release as usize
145 + self.cargo_opts.ignore_rust_version as usize
146 + self.cargo_opts.unit_graph as usize
147 + !self.cargo_opts.packages.is_empty() as usize
148 + !self.cargo_opts.bin.is_empty() as usize
149 + !self.cargo_opts.example.is_empty() as usize
150 + !self.cargo_opts.args.is_empty() as usize
151 + count_common_options(&self.cargo_opts.common)
152 + self.env_options.count_fields()
153 + self.tls_options.count_fields();
154
155 let mut state = serializer.serialize_struct("Watch", field_count)?;
156
157 if self.ignore_changes {
159 state.serialize_field("ignore_changes", &true)?;
160 }
161 if self.only_lambda_apis {
162 state.serialize_field("only_lambda_apis", &true)?;
163 }
164 if !self.invoke_address.is_empty() {
165 state.serialize_field("invoke_address", &self.invoke_address)?;
166 }
167 if self.invoke_port != 0 {
168 state.serialize_field("invoke_port", &self.invoke_port)?;
169 }
170 if self.print_traces {
171 state.serialize_field("print_traces", &true)?;
172 }
173 if self.wait {
174 state.serialize_field("wait", &true)?;
175 }
176 if self.disable_cors {
177 state.serialize_field("disable_cors", &true)?;
178 }
179
180 if let Some(timeout) = &self.timeout {
182 state.serialize_field("timeout", timeout)?;
183 }
184 if let Some(router) = &self.router {
185 state.serialize_field("router", router)?;
186 }
187
188 self.env_options.serialize_fields::<S>(&mut state)?;
190 self.tls_options.serialize_fields::<S>(&mut state)?;
191
192 if let Some(manifest_path) = &self.cargo_opts.manifest_path {
193 state.serialize_field("manifest_path", manifest_path)?;
194 }
195 if self.cargo_opts.release {
196 state.serialize_field("release", &true)?;
197 }
198 if self.cargo_opts.ignore_rust_version {
199 state.serialize_field("ignore_rust_version", &true)?;
200 }
201 if self.cargo_opts.unit_graph {
202 state.serialize_field("unit_graph", &true)?;
203 }
204 if !self.cargo_opts.packages.is_empty() {
205 state.serialize_field("packages", &self.cargo_opts.packages)?;
206 }
207 if !self.cargo_opts.bin.is_empty() {
208 state.serialize_field("bin", &self.cargo_opts.bin)?;
209 }
210 if !self.cargo_opts.example.is_empty() {
211 state.serialize_field("example", &self.cargo_opts.example)?;
212 }
213 if !self.cargo_opts.args.is_empty() {
214 state.serialize_field("args", &self.cargo_opts.args)?;
215 }
216 serialize_common_options::<S>(&mut state, &self.cargo_opts.common)?;
217
218 state.end()
219 }
220}
221
222fn default_invoke_address() -> String {
223 DEFAULT_INVOKE_ADDRESS.to_string()
224}
225
226fn default_invoke_port() -> u16 {
227 DEFAULT_INVOKE_PORT
228}
229
230#[derive(Clone, Debug, Default, Deserialize, Serialize)]
231pub struct WatchConfig {
232 pub router: Option<FunctionRouter>,
233}
234
235#[derive(Clone, Debug, Default)]
236pub struct FunctionRouter {
237 inner: Router<FunctionRoutes>,
238 pub(crate) raw: Vec<Route>,
239}
240
241impl FunctionRouter {
242 pub fn at(
243 &self,
244 path: &str,
245 method: &str,
246 ) -> Result<(String, HashMap<String, String>), MatchError> {
247 let matched = self.inner.at(path)?;
248 let function = matched.value.at(method).ok_or(MatchError::NotFound)?;
249
250 let params = matched
251 .params
252 .iter()
253 .map(|(k, v)| (k.to_string(), v.to_string()))
254 .collect();
255
256 Ok((function.to_string(), params))
257 }
258
259 pub fn insert(&mut self, path: &str, routes: FunctionRoutes) -> Result<(), InsertError> {
260 self.inner.insert(path, routes)
261 }
262
263 pub fn is_empty(&self) -> bool {
264 self.raw.is_empty()
265 }
266}
267
268#[allow(dead_code)]
269fn is_empty_router(router: &Option<FunctionRouter>) -> bool {
270 router.is_none() || router.as_ref().is_some_and(|r| r.is_empty())
271}
272
273#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
274pub struct Route {
275 path: String,
276 #[serde(skip_serializing_if = "Option::is_none")]
277 methods: Option<Vec<String>>,
278 function: String,
279}
280
281#[derive(Clone, Debug, PartialEq)]
282pub enum FunctionRoutes {
283 Single(String),
284 Multiple(HashMap<String, String>),
285}
286
287impl FunctionRoutes {
288 pub fn at(&self, method: &str) -> Option<&str> {
289 match self {
290 FunctionRoutes::Single(function) => Some(function),
291 FunctionRoutes::Multiple(routes) => routes.get(method).map(|s| s.as_str()),
292 }
293 }
294}
295
296struct FunctionRouterVisitor;
297
298impl<'de> Visitor<'de> for FunctionRouterVisitor {
299 type Value = FunctionRouter;
300
301 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
302 formatter.write_str("a map or sequence of function routes")
303 }
304
305 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
306 where
307 A: serde::de::MapAccess<'de>,
308 {
309 let routes: HashMap<String, FunctionRoutes> =
310 Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
311
312 let mut inner = Router::new();
313 let mut raw = Vec::new();
314
315 let mut inverse = HashMap::new();
316
317 for (path, route) in &routes {
318 inner.insert(path, route.clone()).map_err(|e| {
319 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
320 })?;
321
322 match route {
323 FunctionRoutes::Single(function) => {
324 raw.push(Route {
325 path: path.clone(),
326 methods: None,
327 function: function.clone(),
328 });
329 }
330 FunctionRoutes::Multiple(routes) => {
331 for (method, function) in routes {
332 inverse
333 .entry((path.clone(), function.clone()))
334 .and_modify(|route: &mut Route| {
335 let mut methods = route.methods.clone().unwrap_or_default();
336 methods.push(method.clone());
337 route.methods = Some(methods);
338 })
339 .or_insert_with(|| Route {
340 path: path.clone(),
341 methods: Some(vec![method.clone()]),
342 function: function.clone(),
343 });
344 }
345 }
346 }
347 }
348
349 for (_, route) in inverse {
350 raw.push(route);
351 }
352
353 Ok(FunctionRouter { inner, raw })
354 }
355
356 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
357 where
358 A: serde::de::SeqAccess<'de>,
359 {
360 let routes: Vec<Route> =
361 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
362
363 let mut inner = Router::new();
364 let mut raw = Vec::new();
365
366 let mut routes_by_path = HashMap::new();
367
368 for route in &routes {
369 routes_by_path
370 .entry(route.path.clone())
371 .and_modify(|routes| merge_routes(routes, route))
372 .or_insert_with(|| decode_route(route));
373
374 raw.push(route.clone());
375 }
376
377 for (path, route) in &routes_by_path {
378 inner.insert(path, route.clone()).map_err(|e| {
379 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
380 })?;
381 }
382
383 Ok(FunctionRouter { inner, raw })
384 }
385}
386
387fn merge_routes(routes: &mut FunctionRoutes, route: &Route) {
388 let methods = route.methods.clone().unwrap_or_default();
389 match routes {
390 FunctionRoutes::Single(function) if !methods.is_empty() => {
391 let mut tmp = HashMap::new();
392 for method in methods {
393 tmp.insert(method.clone(), function.clone());
394 }
395 *routes = FunctionRoutes::Multiple(tmp);
396 }
397 FunctionRoutes::Multiple(_) if methods.is_empty() => {
398 *routes = FunctionRoutes::Single(route.function.clone());
399 }
400 FunctionRoutes::Multiple(routes) => {
401 for method in methods {
402 routes.insert(method.clone(), route.function.clone());
403 }
404 }
405 FunctionRoutes::Single(_) => {
406 *routes = FunctionRoutes::Single(route.function.clone());
407 }
408 }
409}
410
411fn decode_route(route: &Route) -> FunctionRoutes {
412 match &route.methods {
413 Some(methods) => {
414 let mut routes = HashMap::new();
415 for method in methods {
416 routes.insert(method.clone(), route.function.clone());
417 }
418 FunctionRoutes::Multiple(routes)
419 }
420 None => FunctionRoutes::Single(route.function.clone()),
421 }
422}
423
424impl<'de> Deserialize<'de> for FunctionRouter {
425 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
426 where
427 D: serde::Deserializer<'de>,
428 {
429 deserializer.deserialize_any(FunctionRouterVisitor)
430 }
431}
432
433impl Serialize for FunctionRouter {
434 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
435 where
436 S: serde::Serializer,
437 {
438 self.raw.serialize(serializer)
439 }
440}
441
442impl<'de> Deserialize<'de> for FunctionRoutes {
443 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
444 where
445 D: serde::Deserializer<'de>,
446 {
447 let value = Value::deserialize(deserializer)?;
448 match value {
449 Value::String(s) => Ok(FunctionRoutes::Single(s)),
450 Value::Array(arr) => {
451 let mut routes = HashMap::new();
452 for item in arr {
453 let obj = item.as_object().ok_or_else(|| {
454 Error::custom("Array items must be objects with method and function fields")
455 })?;
456
457 let method = obj
458 .get("method")
459 .and_then(|m| m.as_str())
460 .ok_or_else(|| Error::custom("Missing or invalid method field"))?;
461
462 let function = obj
463 .get("function")
464 .and_then(|f| f.as_str())
465 .ok_or_else(|| Error::custom("Missing or invalid function field"))?;
466
467 routes.insert(method.to_string(), function.to_string());
468 }
469 Ok(FunctionRoutes::Multiple(routes))
470 }
471 _ => Err(Error::custom(
472 "Function routes must be either a string or an array of objects with method and function fields",
473 )),
474 }
475 }
476}
477
478impl Serialize for FunctionRoutes {
479 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
480 where
481 S: serde::Serializer,
482 {
483 match self {
484 FunctionRoutes::Single(function) => function.serialize(serializer),
485 FunctionRoutes::Multiple(routes) => {
486 let mut seq = serializer.serialize_seq(Some(routes.len()))?;
487 for (method, function) in routes {
488 let mut map = serde_json::Map::new();
489 map.insert("method".to_string(), json!(method));
490 map.insert("function".to_string(), json!(function));
491 seq.serialize_element(&Value::Object(map))?;
492 }
493 seq.end()
494 }
495 }
496 }
497}
498
499#[cfg(test)]
500mod tests {
501
502 use cargo_options::CommonOptions;
503 use serde_json::{Value, json};
504 use std::path::PathBuf;
505
506 use super::*;
507
508 #[test]
509 fn test_router_deserialize() {
510 let router: FunctionRouter = toml::from_str(
511 r#"
512 "/api/v1/users" = [
513 { function = "get_user", method = "GET" },
514 { function = "create_user", method = "POST" }
515 ]
516 "/api/v1/all_methods" = "all_methods"
517 "#,
518 )
519 .unwrap();
520
521 assert_eq!(
522 router.inner.at("/api/v1/users").unwrap().value,
523 &FunctionRoutes::Multiple(HashMap::from([
524 ("GET".to_string(), "get_user".to_string()),
525 ("POST".to_string(), "create_user".to_string()),
526 ]))
527 );
528
529 assert_eq!(
530 router.inner.at("/api/v1/all_methods").unwrap().value,
531 &FunctionRoutes::Single("all_methods".to_string())
532 );
533 }
534
535 #[test]
536 fn test_router_get() {
537 let router = FunctionRouter::default();
538 assert_eq!(router.at("/api/v1/users", "GET"), Err(MatchError::NotFound));
539
540 let mut inner = Router::new();
541 inner
542 .insert(
543 "/api/v1/users",
544 FunctionRoutes::Single("user_handler".to_string()),
545 )
546 .unwrap();
547 let router = FunctionRouter {
548 inner,
549 ..Default::default()
550 };
551 assert_eq!(
552 router.at("/api/v1/users", "GET"),
553 Ok(("user_handler".to_string(), HashMap::new()))
554 );
555 assert_eq!(
556 router.at("/api/v1/users", "POST"),
557 Ok(("user_handler".to_string(), HashMap::new()))
558 );
559
560 let mut inner = Router::new();
561 inner
562 .insert(
563 "/api/v1/users",
564 FunctionRoutes::Multiple(HashMap::from([
565 ("GET".to_string(), "get_user".to_string()),
566 ("POST".to_string(), "create_user".to_string()),
567 ])),
568 )
569 .unwrap();
570 let router = FunctionRouter {
571 inner,
572 ..Default::default()
573 };
574 assert_eq!(
575 router.at("/api/v1/users", "GET"),
576 Ok(("get_user".to_string(), HashMap::new()))
577 );
578 assert_eq!(
579 router.at("/api/v1/users", "POST"),
580 Ok(("create_user".to_string(), HashMap::new()))
581 );
582 assert_eq!(router.at("/api/v1/users", "PUT"), Err(MatchError::NotFound));
583
584 let mut inner = Router::new();
585 inner
586 .insert(
587 "/api/v1/users/{id}",
588 FunctionRoutes::Single("user_handler".to_string()),
589 )
590 .unwrap();
591 let router = FunctionRouter {
592 inner,
593 ..Default::default()
594 };
595
596 let (function, params) = router.at("/api/v1/users/1", "GET").unwrap();
597 assert_eq!(function, "user_handler");
598 assert_eq!(params, HashMap::from([("id".to_string(), "1".to_string())]));
599 }
600
601 #[test]
602 fn test_router_serialize() {
603 let config = r#"
604 "/api/v1/users" = [
605 { function = "get_user", method = "GET" },
606 { function = "create_user", method = "POST" }
607 ]
608 "/api/v1/all_methods" = "all_methods"
609 "#;
610 let router: FunctionRouter = toml::from_str(config).unwrap();
611
612 let json = serde_json::to_value(&router).unwrap();
613
614 let new_router: FunctionRouter = serde_json::from_value(json).unwrap();
615 assert_eq!(new_router.raw, router.raw);
616
617 assert_eq!(
618 new_router.inner.at("/api/v1/users").unwrap().value,
619 &FunctionRoutes::Multiple(HashMap::from([
620 ("GET".to_string(), "get_user".to_string()),
621 ("POST".to_string(), "create_user".to_string()),
622 ]))
623 );
624
625 assert_eq!(
626 new_router.inner.at("/api/v1/all_methods").unwrap().value,
627 &FunctionRoutes::Single("all_methods".to_string())
628 );
629 }
630
631 #[test]
632 fn test_watch_serialization() {
633 let watch = Watch {
634 invoke_address: "127.0.0.1".to_string(),
635 invoke_port: 9000,
636 env_options: EnvOptions {
637 env_file: Some(PathBuf::from("/tmp/env")),
638 env_var: Some(vec!["FOO=BAR".to_string()]),
639 },
640 tls_options: TlsOptions::new(
641 Some(PathBuf::from("/tmp/cert.pem")),
642 Some(PathBuf::from("/tmp/key.pem")),
643 Some(PathBuf::from("/tmp/ca.pem")),
644 ),
645 cargo_opts: Run {
646 common: CommonOptions {
647 quiet: false,
648 jobs: None,
649 keep_going: false,
650 profile: None,
651 features: vec!["feature1".to_string()],
652 all_features: false,
653 no_default_features: true,
654 target: vec!["x86_64-unknown-linux-gnu".to_string()],
655 target_dir: Some(PathBuf::from("/tmp/target")),
656 message_format: vec!["json".to_string()],
657 verbose: 1,
658 color: Some("auto".to_string()),
659 frozen: true,
660 locked: true,
661 offline: true,
662 config: vec!["config.toml".to_string()],
663 unstable_flags: vec!["flag1".to_string()],
664 timings: None,
665 },
666 manifest_path: None,
667 release: false,
668 ignore_rust_version: false,
669 unit_graph: false,
670 packages: vec![],
671 bin: vec![],
672 example: vec![],
673 args: vec![],
674 },
675 ..Default::default()
676 };
677
678 let json = serde_json::to_value(&watch).unwrap();
679 assert_eq!(json["invoke_address"], "127.0.0.1");
680 assert_eq!(json["invoke_port"], 9000);
681 assert_eq!(json["env_file"], "/tmp/env");
682 assert_eq!(json["env_var"], json!(["FOO=BAR"]));
683 assert_eq!(json["tls_cert"], "/tmp/cert.pem");
684 assert_eq!(json["tls_key"], "/tmp/key.pem");
685 assert_eq!(json["tls_ca"], "/tmp/ca.pem");
686 assert_eq!(json["features"], json!(["feature1"]));
687 assert_eq!(json["no_default_features"], true);
688 assert_eq!(json["target"], json!(["x86_64-unknown-linux-gnu"]));
689 assert_eq!(json["target_dir"], "/tmp/target");
690 assert_eq!(json["message_format"], json!(["json"]));
691 assert_eq!(json["verbose"], 1);
692 assert_eq!(json["color"], "auto");
693 assert_eq!(json["frozen"], true);
694 assert_eq!(json["locked"], true);
695 assert_eq!(json["offline"], true);
696 assert_eq!(json["config"], json!(["config.toml"]));
697 assert_eq!(json["unstable_flags"], json!(["flag1"]));
698 assert_eq!(json["timings"], Value::Null);
699
700 let deserialized: Watch = serde_json::from_value(json).unwrap();
701
702 assert_eq!(deserialized.invoke_address, watch.invoke_address);
703 assert_eq!(deserialized.invoke_port, watch.invoke_port);
704 assert_eq!(
705 deserialized.env_options.env_file,
706 watch.env_options.env_file
707 );
708 assert_eq!(deserialized.env_options.env_var, watch.env_options.env_var);
709 assert_eq!(
710 deserialized.tls_options.tls_cert,
711 watch.tls_options.tls_cert
712 );
713 assert_eq!(deserialized.tls_options.tls_key, watch.tls_options.tls_key);
714 assert_eq!(deserialized.tls_options.tls_ca, watch.tls_options.tls_ca);
715 assert_eq!(
716 deserialized.cargo_opts.common.features,
717 watch.cargo_opts.common.features
718 );
719 assert_eq!(
720 deserialized.cargo_opts.common.no_default_features,
721 watch.cargo_opts.common.no_default_features
722 );
723 assert_eq!(
724 deserialized.cargo_opts.common.target,
725 watch.cargo_opts.common.target
726 );
727 assert_eq!(
728 deserialized.cargo_opts.common.target_dir,
729 watch.cargo_opts.common.target_dir
730 );
731 assert_eq!(
732 deserialized.cargo_opts.common.message_format,
733 watch.cargo_opts.common.message_format
734 );
735 assert_eq!(
736 deserialized.cargo_opts.common.verbose,
737 watch.cargo_opts.common.verbose
738 );
739 assert_eq!(
740 deserialized.cargo_opts.common.color,
741 watch.cargo_opts.common.color
742 );
743 assert_eq!(
744 deserialized.cargo_opts.common.frozen,
745 watch.cargo_opts.common.frozen
746 );
747 assert_eq!(
748 deserialized.cargo_opts.common.locked,
749 watch.cargo_opts.common.locked
750 );
751 assert_eq!(
752 deserialized.cargo_opts.common.offline,
753 watch.cargo_opts.common.offline
754 );
755 assert_eq!(
756 deserialized.cargo_opts.common.config,
757 watch.cargo_opts.common.config
758 );
759 assert_eq!(
760 deserialized.cargo_opts.common.unstable_flags,
761 watch.cargo_opts.common.unstable_flags
762 );
763 assert_eq!(
764 deserialized.cargo_opts.common.timings,
765 watch.cargo_opts.common.timings
766 );
767 }
768}