1use cargo_options::Run;
2use clap::Args;
3use matchit::{InsertError, MatchError, Router};
4use serde::{
5 Deserialize, Serialize,
6 de::{Error, Visitor},
7 ser::SerializeSeq,
8};
9use serde_json::{Value, json};
10use std::{collections::HashMap, path::PathBuf};
11
12use crate::{
13 cargo::{count_common_options, serialize_common_options},
14 env::{EnvOptions, Environment},
15 error::MetadataError,
16 lambda::Timeout,
17};
18
19use cargo_lambda_remote::tls::TlsOptions;
20
/// Default for `Watch::invoke_address`, used both as the clap `default_value`
/// and by the serde default `default_invoke_address`.
///
/// Windows uses the IPv4 loopback address `127.0.0.1`; every other platform
/// uses `::`, the IPv6 unspecified address.
const DEFAULT_INVOKE_ADDRESS: &str = if cfg!(windows) { "127.0.0.1" } else { "::" };

/// Default for `Watch::invoke_port`, shared by the CLI and the serde default.
const DEFAULT_INVOKE_PORT: u16 = 9000;
28
// Options for the `watch` subcommand (visible alias `start`).
//
// NOTE: this struct is documented with plain `//` comments on purpose. clap's
// derive turns `///` doc comments on the struct and its fields into `--help`
// text, so documenting with `///` here would change the CLI output.
//
// The struct also derives `Deserialize`, so these options can be read from
// serialized input as well as from the command line. `Serialize` is
// implemented by hand further down in this file.
//
// The derived `Default` leaves `invoke_address` empty and `invoke_port` at `0`;
// the clap and serde defaults below only apply when parsing/deserializing.
#[derive(Args, Clone, Debug, Default, Deserialize)]
#[command(
    name = "watch",
    visible_alias = "start",
    after_help = "Full command documentation: https://www.cargo-lambda.info/commands/watch.html"
)]
pub struct Watch {
    // `--ignore-changes`, also accepted as `--no-reload`.
    #[arg(long, visible_alias = "no-reload")]
    #[serde(default)]
    pub ignore_changes: bool,

    // `--only-lambda-apis`.
    #[arg(long)]
    #[serde(default)]
    pub only_lambda_apis: bool,

    // `-a` / `--invoke-address`. Falls back to `DEFAULT_INVOKE_ADDRESS` both on
    // the command line and when the field is missing from deserialized input.
    #[arg(short = 'a', long, default_value = DEFAULT_INVOKE_ADDRESS)]
    #[serde(default = "default_invoke_address")]
    pub invoke_address: String,

    // `-P` / `--invoke-port`. Falls back to `DEFAULT_INVOKE_PORT` both on the
    // command line and when the field is missing from deserialized input.
    #[arg(short = 'P', long, default_value_t = DEFAULT_INVOKE_PORT)]
    #[serde(default = "default_invoke_port")]
    pub invoke_port: u16,

    // `--print-traces`.
    #[arg(long)]
    #[serde(default)]
    pub print_traces: bool,

    // `-w` / `--wait` (the short flag is derived from the field name).
    #[arg(long, short)]
    #[serde(default)]
    pub wait: bool,

    // `--disable-cors`.
    #[arg(long)]
    #[serde(default)]
    pub disable_cors: bool,

    // `--timeout`; `None` when not given.
    #[arg(long)]
    #[serde(default)]
    pub timeout: Option<Timeout>,

    // Cargo `run`-style options, flattened into both the CLI and the serde form.
    #[command(flatten)]
    #[serde(flatten)]
    pub cargo_opts: Run,

    // Environment options (env file / env vars), flattened likewise.
    #[command(flatten)]
    #[serde(flatten)]
    pub env_options: EnvOptions,

    // TLS options, flattened likewise.
    #[command(flatten)]
    #[serde(flatten)]
    pub tls_options: TlsOptions,

    // Not a CLI flag (`#[arg(skip)]`); only set through deserialization.
    // Because `Serialize` for `Watch` is hand-written, serde's derive does not
    // apply this `skip_serializing_if`; the manual impl decides when `router`
    // is emitted.
    #[arg(skip)]
    #[serde(default, skip_serializing_if = "is_empty_router")]
    pub router: Option<FunctionRouter>,
}
93
94impl Watch {
95 pub fn manifest_path(&self) -> PathBuf {
96 self.cargo_opts
97 .manifest_path
98 .clone()
99 .unwrap_or_else(|| "Cargo.toml".into())
100 }
101
102 pub fn package(&self) -> Option<String> {
105 if self.cargo_opts.packages.len() > 1 {
106 return None;
107 }
108 self.cargo_opts.packages.first().map(|s| s.to_string())
109 }
110
111 pub fn lambda_environment(
112 &self,
113 base: &HashMap<String, String>,
114 ) -> Result<Environment, MetadataError> {
115 self.env_options.lambda_environment(base)
116 }
117}
118
119impl Serialize for Watch {
120 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121 where
122 S: serde::Serializer,
123 {
124 use serde::ser::SerializeStruct;
125
126 let field_count = self.ignore_changes as usize
128 + self.only_lambda_apis as usize
129 + !self.invoke_address.is_empty() as usize
130 + (self.invoke_port != 0) as usize
131 + self.print_traces as usize
132 + self.wait as usize
133 + self.disable_cors as usize
134 + self.timeout.is_some() as usize
135 + self.router.is_some() as usize
136 + self.cargo_opts.manifest_path.is_some() as usize
137 + self.cargo_opts.release as usize
138 + self.cargo_opts.ignore_rust_version as usize
139 + self.cargo_opts.unit_graph as usize
140 + !self.cargo_opts.packages.is_empty() as usize
141 + !self.cargo_opts.bin.is_empty() as usize
142 + !self.cargo_opts.example.is_empty() as usize
143 + !self.cargo_opts.args.is_empty() as usize
144 + count_common_options(&self.cargo_opts.common)
145 + self.env_options.count_fields()
146 + self.tls_options.count_fields();
147
148 let mut state = serializer.serialize_struct("Watch", field_count)?;
149
150 if self.ignore_changes {
152 state.serialize_field("ignore_changes", &true)?;
153 }
154 if self.only_lambda_apis {
155 state.serialize_field("only_lambda_apis", &true)?;
156 }
157 if !self.invoke_address.is_empty() {
158 state.serialize_field("invoke_address", &self.invoke_address)?;
159 }
160 if self.invoke_port != 0 {
161 state.serialize_field("invoke_port", &self.invoke_port)?;
162 }
163 if self.print_traces {
164 state.serialize_field("print_traces", &true)?;
165 }
166 if self.wait {
167 state.serialize_field("wait", &true)?;
168 }
169 if self.disable_cors {
170 state.serialize_field("disable_cors", &true)?;
171 }
172
173 if let Some(timeout) = &self.timeout {
175 state.serialize_field("timeout", timeout)?;
176 }
177 if let Some(router) = &self.router {
178 state.serialize_field("router", router)?;
179 }
180
181 self.env_options.serialize_fields::<S>(&mut state)?;
183 self.tls_options.serialize_fields::<S>(&mut state)?;
184
185 if let Some(manifest_path) = &self.cargo_opts.manifest_path {
186 state.serialize_field("manifest_path", manifest_path)?;
187 }
188 if self.cargo_opts.release {
189 state.serialize_field("release", &true)?;
190 }
191 if self.cargo_opts.ignore_rust_version {
192 state.serialize_field("ignore_rust_version", &true)?;
193 }
194 if self.cargo_opts.unit_graph {
195 state.serialize_field("unit_graph", &true)?;
196 }
197 if !self.cargo_opts.packages.is_empty() {
198 state.serialize_field("packages", &self.cargo_opts.packages)?;
199 }
200 if !self.cargo_opts.bin.is_empty() {
201 state.serialize_field("bin", &self.cargo_opts.bin)?;
202 }
203 if !self.cargo_opts.example.is_empty() {
204 state.serialize_field("example", &self.cargo_opts.example)?;
205 }
206 if !self.cargo_opts.args.is_empty() {
207 state.serialize_field("args", &self.cargo_opts.args)?;
208 }
209 serialize_common_options::<S>(&mut state, &self.cargo_opts.common)?;
210
211 state.end()
212 }
213}
214
215fn default_invoke_address() -> String {
216 DEFAULT_INVOKE_ADDRESS.to_string()
217}
218
219fn default_invoke_port() -> u16 {
220 DEFAULT_INVOKE_PORT
221}
222
/// Serializable watch settings; currently only the optional function router.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct WatchConfig {
    /// Path/method routing table; `None` when not set.
    pub router: Option<FunctionRouter>,
}
227
/// Routes request paths (and optionally methods such as `GET`) to function names.
///
/// `inner` is the matchit router used for lookups. `raw` is the flat list of
/// [`Route`]s that `Serialize` writes out and that [`FunctionRouter::is_empty`]
/// inspects. Both deserializer paths populate both fields.
#[derive(Clone, Debug, Default)]
pub struct FunctionRouter {
    // Path pattern -> function(s) for that path; used by `FunctionRouter::at`.
    inner: Router<FunctionRoutes>,
    // Flat route list, serialized verbatim by `Serialize for FunctionRouter`.
    pub(crate) raw: Vec<Route>,
}
233
234impl FunctionRouter {
235 pub fn at(
236 &self,
237 path: &str,
238 method: &str,
239 ) -> Result<(String, HashMap<String, String>), MatchError> {
240 let matched = self.inner.at(path)?;
241 let function = matched.value.at(method).ok_or(MatchError::NotFound)?;
242
243 let params = matched
244 .params
245 .iter()
246 .map(|(k, v)| (k.to_string(), v.to_string()))
247 .collect();
248
249 Ok((function.to_string(), params))
250 }
251
252 pub fn insert(&mut self, path: &str, routes: FunctionRoutes) -> Result<(), InsertError> {
253 self.inner.insert(path, routes)
254 }
255
256 pub fn is_empty(&self) -> bool {
257 self.raw.is_empty()
258 }
259}
260
261#[allow(dead_code)]
262fn is_empty_router(router: &Option<FunctionRouter>) -> bool {
263 router.is_none() || router.as_ref().is_some_and(|r| r.is_empty())
264}
265
/// One entry of the flat (sequence) form of the router configuration.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Route {
    /// Path pattern; may contain parameters such as `{id}`.
    path: String,
    /// Methods served by `function`. `None` (omitted when serializing) means
    /// the route serves every method.
    #[serde(skip_serializing_if = "Option::is_none")]
    methods: Option<Vec<String>>,
    /// Name of the function that handles matching requests.
    function: String,
}
273
/// The function(s) registered for a single path.
#[derive(Clone, Debug, PartialEq)]
pub enum FunctionRoutes {
    /// One function that handles every method.
    Single(String),
    /// Method -> function; methods that are not listed are not routed.
    Multiple(HashMap<String, String>),
}
279
280impl FunctionRoutes {
281 pub fn at(&self, method: &str) -> Option<&str> {
282 match self {
283 FunctionRoutes::Single(function) => Some(function),
284 FunctionRoutes::Multiple(routes) => routes.get(method).map(|s| s.as_str()),
285 }
286 }
287}
288
/// Serde visitor that builds a [`FunctionRouter`] from either a map keyed by
/// path or a sequence of [`Route`]s.
struct FunctionRouterVisitor;
290
291impl<'de> Visitor<'de> for FunctionRouterVisitor {
292 type Value = FunctionRouter;
293
294 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
295 formatter.write_str("a map or sequence of function routes")
296 }
297
298 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
299 where
300 A: serde::de::MapAccess<'de>,
301 {
302 let routes: HashMap<String, FunctionRoutes> =
303 Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
304
305 let mut inner = Router::new();
306 let mut raw = Vec::new();
307
308 let mut inverse = HashMap::new();
309
310 for (path, route) in &routes {
311 inner.insert(path, route.clone()).map_err(|e| {
312 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
313 })?;
314
315 match route {
316 FunctionRoutes::Single(function) => {
317 raw.push(Route {
318 path: path.clone(),
319 methods: None,
320 function: function.clone(),
321 });
322 }
323 FunctionRoutes::Multiple(routes) => {
324 for (method, function) in routes {
325 inverse
326 .entry((path.clone(), function.clone()))
327 .and_modify(|route: &mut Route| {
328 let mut methods = route.methods.clone().unwrap_or_default();
329 methods.push(method.clone());
330 route.methods = Some(methods);
331 })
332 .or_insert_with(|| Route {
333 path: path.clone(),
334 methods: Some(vec![method.clone()]),
335 function: function.clone(),
336 });
337 }
338 }
339 }
340 }
341
342 for (_, route) in inverse {
343 raw.push(route);
344 }
345
346 Ok(FunctionRouter { inner, raw })
347 }
348
349 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
350 where
351 A: serde::de::SeqAccess<'de>,
352 {
353 let routes: Vec<Route> =
354 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
355
356 let mut inner = Router::new();
357 let mut raw = Vec::new();
358
359 let mut routes_by_path = HashMap::new();
360
361 for route in &routes {
362 routes_by_path
363 .entry(route.path.clone())
364 .and_modify(|routes| merge_routes(routes, route))
365 .or_insert_with(|| decode_route(route));
366
367 raw.push(route.clone());
368 }
369
370 for (path, route) in &routes_by_path {
371 inner.insert(path, route.clone()).map_err(|e| {
372 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
373 })?;
374 }
375
376 Ok(FunctionRouter { inner, raw })
377 }
378}
379
380fn merge_routes(routes: &mut FunctionRoutes, route: &Route) {
381 let methods = route.methods.clone().unwrap_or_default();
382 match routes {
383 FunctionRoutes::Single(function) if !methods.is_empty() => {
384 let mut tmp = HashMap::new();
385 for method in methods {
386 tmp.insert(method.clone(), function.clone());
387 }
388 *routes = FunctionRoutes::Multiple(tmp);
389 }
390 FunctionRoutes::Multiple(_) if methods.is_empty() => {
391 *routes = FunctionRoutes::Single(route.function.clone());
392 }
393 FunctionRoutes::Multiple(routes) => {
394 for method in methods {
395 routes.insert(method.clone(), route.function.clone());
396 }
397 }
398 FunctionRoutes::Single(_) => {
399 *routes = FunctionRoutes::Single(route.function.clone());
400 }
401 }
402}
403
404fn decode_route(route: &Route) -> FunctionRoutes {
405 match &route.methods {
406 Some(methods) => {
407 let mut routes = HashMap::new();
408 for method in methods {
409 routes.insert(method.clone(), route.function.clone());
410 }
411 FunctionRoutes::Multiple(routes)
412 }
413 None => FunctionRoutes::Single(route.function.clone()),
414 }
415}
416
417impl<'de> Deserialize<'de> for FunctionRouter {
418 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
419 where
420 D: serde::Deserializer<'de>,
421 {
422 deserializer.deserialize_any(FunctionRouterVisitor)
423 }
424}
425
426impl Serialize for FunctionRouter {
427 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
428 where
429 S: serde::Serializer,
430 {
431 self.raw.serialize(serializer)
432 }
433}
434
435impl<'de> Deserialize<'de> for FunctionRoutes {
436 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
437 where
438 D: serde::Deserializer<'de>,
439 {
440 let value = Value::deserialize(deserializer)?;
441 match value {
442 Value::String(s) => Ok(FunctionRoutes::Single(s)),
443 Value::Array(arr) => {
444 let mut routes = HashMap::new();
445 for item in arr {
446 let obj = item.as_object().ok_or_else(|| {
447 Error::custom("Array items must be objects with method and function fields")
448 })?;
449
450 let method = obj
451 .get("method")
452 .and_then(|m| m.as_str())
453 .ok_or_else(|| Error::custom("Missing or invalid method field"))?;
454
455 let function = obj
456 .get("function")
457 .and_then(|f| f.as_str())
458 .ok_or_else(|| Error::custom("Missing or invalid function field"))?;
459
460 routes.insert(method.to_string(), function.to_string());
461 }
462 Ok(FunctionRoutes::Multiple(routes))
463 }
464 _ => Err(Error::custom(
465 "Function routes must be either a string or an array of objects with method and function fields",
466 )),
467 }
468 }
469}
470
471impl Serialize for FunctionRoutes {
472 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
473 where
474 S: serde::Serializer,
475 {
476 match self {
477 FunctionRoutes::Single(function) => function.serialize(serializer),
478 FunctionRoutes::Multiple(routes) => {
479 let mut seq = serializer.serialize_seq(Some(routes.len()))?;
480 for (method, function) in routes {
481 let mut map = serde_json::Map::new();
482 map.insert("method".to_string(), json!(method));
483 map.insert("function".to_string(), json!(function));
484 seq.serialize_element(&Value::Object(map))?;
485 }
486 seq.end()
487 }
488 }
489 }
490}
491
// Unit tests for router (de)serialization, router lookups, and the
// hand-written `Serialize` impl for `Watch`.
#[cfg(test)]
mod tests {

    use cargo_options::CommonOptions;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use super::*;

    // The map form (TOML) accepts both per-method arrays and catch-all strings.
    #[test]
    fn test_router_deserialize() {
        let router: FunctionRouter = toml::from_str(
            r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
            "#,
        )
        .unwrap();

        assert_eq!(
            router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // `FunctionRouter::at` against an empty router, a catch-all route, a
    // per-method route, and a route with a path parameter. Routers are built
    // directly from `inner`, so `raw` stays at its default here.
    #[test]
    fn test_router_get() {
        let router = FunctionRouter::default();
        assert_eq!(router.at("/api/v1/users", "GET"), Err(MatchError::NotFound));

        // Catch-all: every method resolves to the same function.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );

        // Per-method: methods that are not listed are `NotFound`.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Multiple(HashMap::from([
                    ("GET".to_string(), "get_user".to_string()),
                    ("POST".to_string(), "create_user".to_string()),
                ])),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("get_user".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("create_user".to_string(), HashMap::new()))
        );
        assert_eq!(router.at("/api/v1/users", "PUT"), Err(MatchError::NotFound));

        // Path parameters are returned alongside the function name.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users/{id}",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };

        let (function, params) = router.at("/api/v1/users/1", "GET").unwrap();
        assert_eq!(function, "user_handler");
        assert_eq!(params, HashMap::from([("id".to_string(), "1".to_string())]));
    }

    // Round trip: map form (TOML) -> flat sequence (JSON) -> router. The flat
    // `raw` list and the lookups must survive unchanged.
    #[test]
    fn test_router_serialize() {
        let config = r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
            "#;
        let router: FunctionRouter = toml::from_str(config).unwrap();

        let json = serde_json::to_value(&router).unwrap();

        let new_router: FunctionRouter = serde_json::from_value(json).unwrap();
        assert_eq!(new_router.raw, router.raw);

        assert_eq!(
            new_router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            new_router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // The flattened cargo, env and TLS options survive a JSON round trip
    // through the hand-written `Serialize` impl and the derived `Deserialize`.
    #[test]
    fn test_watch_serialization() {
        let watch = Watch {
            invoke_address: "127.0.0.1".to_string(),
            invoke_port: 9000,
            env_options: EnvOptions {
                env_file: Some(PathBuf::from("/tmp/env")),
                env_var: Some(vec!["FOO=BAR".to_string()]),
            },
            tls_options: TlsOptions::new(
                Some(PathBuf::from("/tmp/cert.pem")),
                Some(PathBuf::from("/tmp/key.pem")),
                Some(PathBuf::from("/tmp/ca.pem")),
            ),
            cargo_opts: Run {
                common: CommonOptions {
                    quiet: false,
                    jobs: None,
                    keep_going: false,
                    profile: None,
                    features: vec!["feature1".to_string()],
                    all_features: false,
                    no_default_features: true,
                    target: vec!["x86_64-unknown-linux-gnu".to_string()],
                    target_dir: Some(PathBuf::from("/tmp/target")),
                    message_format: vec!["json".to_string()],
                    verbose: 1,
                    color: Some("auto".to_string()),
                    frozen: true,
                    locked: true,
                    offline: true,
                    config: vec!["config.toml".to_string()],
                    unstable_flags: vec!["flag1".to_string()],
                    timings: None,
                },
                manifest_path: None,
                release: false,
                ignore_rust_version: false,
                unit_graph: false,
                packages: vec![],
                bin: vec![],
                example: vec![],
                args: vec![],
            },
            ..Default::default()
        };

        // Serialized keys are flat: no nested `cargo_opts`/`env_options`/`tls_options`.
        let json = serde_json::to_value(&watch).unwrap();
        assert_eq!(json["invoke_address"], "127.0.0.1");
        assert_eq!(json["invoke_port"], 9000);
        assert_eq!(json["env_file"], "/tmp/env");
        assert_eq!(json["env_var"], json!(["FOO=BAR"]));
        assert_eq!(json["tls_cert"], "/tmp/cert.pem");
        assert_eq!(json["tls_key"], "/tmp/key.pem");
        assert_eq!(json["tls_ca"], "/tmp/ca.pem");
        assert_eq!(json["features"], json!(["feature1"]));
        assert_eq!(json["no_default_features"], true);
        assert_eq!(json["target"], json!(["x86_64-unknown-linux-gnu"]));
        assert_eq!(json["target_dir"], "/tmp/target");
        assert_eq!(json["message_format"], json!(["json"]));
        assert_eq!(json["verbose"], 1);
        assert_eq!(json["color"], "auto");
        assert_eq!(json["frozen"], true);
        assert_eq!(json["locked"], true);
        assert_eq!(json["offline"], true);
        assert_eq!(json["config"], json!(["config.toml"]));
        assert_eq!(json["unstable_flags"], json!(["flag1"]));
        // Indexing a missing key yields `Null`, so this also passes when `timings` is omitted.
        assert_eq!(json["timings"], Value::Null);

        let deserialized: Watch = serde_json::from_value(json).unwrap();

        assert_eq!(deserialized.invoke_address, watch.invoke_address);
        assert_eq!(deserialized.invoke_port, watch.invoke_port);
        assert_eq!(
            deserialized.env_options.env_file,
            watch.env_options.env_file
        );
        assert_eq!(deserialized.env_options.env_var, watch.env_options.env_var);
        assert_eq!(
            deserialized.tls_options.tls_cert,
            watch.tls_options.tls_cert
        );
        assert_eq!(deserialized.tls_options.tls_key, watch.tls_options.tls_key);
        assert_eq!(deserialized.tls_options.tls_ca, watch.tls_options.tls_ca);
        assert_eq!(
            deserialized.cargo_opts.common.features,
            watch.cargo_opts.common.features
        );
        assert_eq!(
            deserialized.cargo_opts.common.no_default_features,
            watch.cargo_opts.common.no_default_features
        );
        assert_eq!(
            deserialized.cargo_opts.common.target,
            watch.cargo_opts.common.target
        );
        assert_eq!(
            deserialized.cargo_opts.common.target_dir,
            watch.cargo_opts.common.target_dir
        );
        assert_eq!(
            deserialized.cargo_opts.common.message_format,
            watch.cargo_opts.common.message_format
        );
        assert_eq!(
            deserialized.cargo_opts.common.verbose,
            watch.cargo_opts.common.verbose
        );
        assert_eq!(
            deserialized.cargo_opts.common.color,
            watch.cargo_opts.common.color
        );
        assert_eq!(
            deserialized.cargo_opts.common.frozen,
            watch.cargo_opts.common.frozen
        );
        assert_eq!(
            deserialized.cargo_opts.common.locked,
            watch.cargo_opts.common.locked
        );
        assert_eq!(
            deserialized.cargo_opts.common.offline,
            watch.cargo_opts.common.offline
        );
        assert_eq!(
            deserialized.cargo_opts.common.config,
            watch.cargo_opts.common.config
        );
        assert_eq!(
            deserialized.cargo_opts.common.unstable_flags,
            watch.cargo_opts.common.unstable_flags
        );
        assert_eq!(
            deserialized.cargo_opts.common.timings,
            watch.cargo_opts.common.timings
        );
    }
}