1use cargo_options::Run;
2use clap::Args;
3use matchit::{InsertError, MatchError, Router};
4use serde::{
5 Deserialize, Serialize,
6 de::{Error, Visitor},
7 ser::SerializeSeq,
8};
9use serde_json::{Value, json};
10use std::{collections::HashMap, path::PathBuf};
11
12use crate::{
13 cargo::{count_common_options, serialize_common_options},
14 env::{EnvOptions, Environment},
15 error::MetadataError,
16 lambda::Timeout,
17};
18
19use cargo_lambda_remote::tls::TlsOptions;
20
/// Port the local invoke server uses unless `--invoke-port` / `invoke_port` overrides it.
const DEFAULT_INVOKE_PORT: u16 = 9000;

/// Default bind address for the invoke server on non-Windows targets: `::`, the
/// IPv6 unspecified address.
#[cfg(not(windows))]
const DEFAULT_INVOKE_ADDRESS: &str = "::";

/// Default bind address for the invoke server on Windows targets: IPv4 loopback.
// NOTE(review): presumably Windows avoids `::` because of how dual-stack binding
// behaves there — confirm before changing either value.
#[cfg(windows)]
const DEFAULT_INVOKE_ADDRESS: &str = "127.0.0.1";
28
// Options for the `watch` subcommand (visible alias `start`).
//
// Plain `//` comments are used deliberately in this struct: clap's derive turns
// `///` doc comments into `--help` text, so adding them would change CLI output.
//
// The same struct is filled either from the command line (clap `Args`) or from
// configuration (serde `Deserialize`); it is written back out by the hand-written
// `Serialize` impl below, which omits options that are unset.
#[derive(Args, Clone, Debug, Default, Deserialize)]
#[command(
    name = "watch",
    visible_alias = "start",
    after_help = "Full command documentation: https://www.cargo-lambda.info/commands/watch.html"
)]
pub struct Watch {
    // `--ignore-changes` (visible alias `--no-reload`); `false` when absent from config.
    #[arg(long, visible_alias = "no-reload")]
    #[serde(default)]
    pub ignore_changes: bool,

    // `--only-lambda-apis`; `false` when absent from config.
    #[arg(long)]
    #[serde(default)]
    pub only_lambda_apis: bool,

    // `-a` / `--invoke-address`. Both the CLI default and the config default
    // (`default_invoke_address`) come from `DEFAULT_INVOKE_ADDRESS`.
    #[arg(short = 'a', long, default_value = DEFAULT_INVOKE_ADDRESS)]
    #[serde(default = "default_invoke_address")]
    pub invoke_address: String,

    // `-P` / `--invoke-port`. Both the CLI default and the config default
    // (`default_invoke_port`) come from `DEFAULT_INVOKE_PORT`.
    #[arg(short = 'P', long, default_value_t = DEFAULT_INVOKE_PORT)]
    #[serde(default = "default_invoke_port")]
    pub invoke_port: u16,

    // `--print-traces`; `false` when absent from config.
    #[arg(long)]
    #[serde(default)]
    pub print_traces: bool,

    // `-w` / `--wait` (short flag derived from the field name); `false` when absent.
    #[arg(long, short)]
    #[serde(default)]
    pub wait: bool,

    // `--disable-cors`; `false` when absent from config.
    #[arg(long)]
    #[serde(default)]
    pub disable_cors: bool,

    // `--timeout`; `None` when not given.
    #[arg(long)]
    #[serde(default)]
    pub timeout: Option<Timeout>,

    // `cargo run`-style options (`cargo_options::Run`), flattened into both the
    // CLI and the config so their fields appear at the top level.
    #[command(flatten)]
    #[serde(flatten)]
    pub cargo_opts: Run,

    // Environment-variable options, flattened like `cargo_opts`.
    #[command(flatten)]
    #[serde(flatten)]
    pub env_options: EnvOptions,

    // TLS options, flattened like `cargo_opts`.
    #[command(flatten)]
    #[serde(flatten)]
    pub tls_options: TlsOptions,

    // Not exposed as a CLI flag (`#[arg(skip)]`); can only be set from configuration.
    #[arg(skip)]
    #[serde(default)]
    pub router: Option<FunctionRouter>,
}
93
94impl Watch {
95 pub fn manifest_path(&self) -> PathBuf {
96 self.cargo_opts
97 .manifest_path
98 .clone()
99 .unwrap_or_else(|| "Cargo.toml".into())
100 }
101
102 pub fn package(&self) -> Option<String> {
105 if self.cargo_opts.packages.len() > 1 {
106 return None;
107 }
108 self.cargo_opts.packages.first().map(|s| s.to_string())
109 }
110
111 pub fn lambda_environment(
112 &self,
113 base: &HashMap<String, String>,
114 ) -> Result<Environment, MetadataError> {
115 self.env_options.lambda_environment(base)
116 }
117}
118
119impl Serialize for Watch {
120 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121 where
122 S: serde::Serializer,
123 {
124 use serde::ser::SerializeStruct;
125
126 let field_count = self.ignore_changes as usize
128 + self.only_lambda_apis as usize
129 + !self.invoke_address.is_empty() as usize
130 + (self.invoke_port != 0) as usize
131 + self.print_traces as usize
132 + self.wait as usize
133 + self.disable_cors as usize
134 + self.timeout.is_some() as usize
135 + self.router.is_some() as usize
136 + self.cargo_opts.manifest_path.is_some() as usize
137 + self.cargo_opts.release as usize
138 + self.cargo_opts.ignore_rust_version as usize
139 + self.cargo_opts.unit_graph as usize
140 + !self.cargo_opts.packages.is_empty() as usize
141 + !self.cargo_opts.bin.is_empty() as usize
142 + !self.cargo_opts.example.is_empty() as usize
143 + !self.cargo_opts.args.is_empty() as usize
144 + count_common_options(&self.cargo_opts.common)
145 + self.env_options.count_fields()
146 + self.tls_options.count_fields();
147
148 let mut state = serializer.serialize_struct("Watch", field_count)?;
149
150 if self.ignore_changes {
152 state.serialize_field("ignore_changes", &true)?;
153 }
154 if self.only_lambda_apis {
155 state.serialize_field("only_lambda_apis", &true)?;
156 }
157 if !self.invoke_address.is_empty() {
158 state.serialize_field("invoke_address", &self.invoke_address)?;
159 }
160 if self.invoke_port != 0 {
161 state.serialize_field("invoke_port", &self.invoke_port)?;
162 }
163 if self.print_traces {
164 state.serialize_field("print_traces", &true)?;
165 }
166 if self.wait {
167 state.serialize_field("wait", &true)?;
168 }
169 if self.disable_cors {
170 state.serialize_field("disable_cors", &true)?;
171 }
172
173 if let Some(timeout) = &self.timeout {
175 state.serialize_field("timeout", timeout)?;
176 }
177 if let Some(router) = &self.router {
178 state.serialize_field("router", router)?;
179 }
180
181 self.env_options.serialize_fields::<S>(&mut state)?;
183 self.tls_options.serialize_fields::<S>(&mut state)?;
184
185 if let Some(manifest_path) = &self.cargo_opts.manifest_path {
186 state.serialize_field("manifest_path", manifest_path)?;
187 }
188 if self.cargo_opts.release {
189 state.serialize_field("release", &true)?;
190 }
191 if self.cargo_opts.ignore_rust_version {
192 state.serialize_field("ignore_rust_version", &true)?;
193 }
194 if self.cargo_opts.unit_graph {
195 state.serialize_field("unit_graph", &true)?;
196 }
197 if !self.cargo_opts.packages.is_empty() {
198 state.serialize_field("packages", &self.cargo_opts.packages)?;
199 }
200 if !self.cargo_opts.bin.is_empty() {
201 state.serialize_field("bin", &self.cargo_opts.bin)?;
202 }
203 if !self.cargo_opts.example.is_empty() {
204 state.serialize_field("example", &self.cargo_opts.example)?;
205 }
206 if !self.cargo_opts.args.is_empty() {
207 state.serialize_field("args", &self.cargo_opts.args)?;
208 }
209 serialize_common_options::<S>(&mut state, &self.cargo_opts.common)?;
210
211 state.end()
212 }
213}
214
215fn default_invoke_address() -> String {
216 DEFAULT_INVOKE_ADDRESS.to_string()
217}
218
219fn default_invoke_port() -> u16 {
220 DEFAULT_INVOKE_PORT
221}
222
/// Watch-related settings that can be loaded from configuration.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct WatchConfig {
    /// Optional HTTP routing table mapping request paths (and methods) to functions.
    pub router: Option<FunctionRouter>,
}
227
/// Routes incoming request paths and HTTP methods to function names.
///
/// Serialization and deserialization are implemented by hand further down in this file.
#[derive(Clone, Debug, Default)]
pub struct FunctionRouter {
    /// `matchit` router used for lookups; each path pattern maps to its `FunctionRoutes`.
    inner: Router<FunctionRoutes>,
    /// Flat list of route entries; this is what `Serialize for FunctionRouter` writes out.
    pub(crate) raw: Vec<Route>,
}
233
234impl FunctionRouter {
235 pub fn at(
236 &self,
237 path: &str,
238 method: &str,
239 ) -> Result<(String, HashMap<String, String>), MatchError> {
240 let matched = self.inner.at(path)?;
241 let function = matched.value.at(method).ok_or(MatchError::NotFound)?;
242
243 let params = matched
244 .params
245 .iter()
246 .map(|(k, v)| (k.to_string(), v.to_string()))
247 .collect();
248
249 Ok((function.to_string(), params))
250 }
251
252 pub fn insert(&mut self, path: &str, routes: FunctionRoutes) -> Result<(), InsertError> {
253 self.inner.insert(path, routes)
254 }
255}
256
/// One entry of the sequence form of the router configuration, which is also the shape
/// `FunctionRouter` serializes to.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Route {
    /// Path pattern handed to the `matchit` router.
    path: String,
    /// HTTP methods this entry applies to; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    methods: Option<Vec<String>>,
    /// Name of the function that handles matching requests.
    function: String,
}
264
/// The function(s) registered for one path pattern.
#[derive(Clone, Debug, PartialEq)]
pub enum FunctionRoutes {
    /// A single function handles every HTTP method on the path.
    Single(String),
    /// Maps a method name (for example `"GET"`) to a function name; unlisted methods
    /// have no handler.
    Multiple(HashMap<String, String>),
}

impl FunctionRoutes {
    /// Returns the function that handles `method`.
    ///
    /// A `Single` route answers every method. A `Multiple` route answers only the methods
    /// it lists, compared as exact (case-sensitive) strings, and returns `None` otherwise.
    pub fn at(&self, method: &str) -> Option<&str> {
        match self {
            Self::Multiple(by_method) => by_method.get(method).map(String::as_str),
            Self::Single(function) => Some(function.as_str()),
        }
    }
}
279
280struct FunctionRouterVisitor;
281
282impl<'de> Visitor<'de> for FunctionRouterVisitor {
283 type Value = FunctionRouter;
284
285 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
286 formatter.write_str("a map or sequence of function routes")
287 }
288
289 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
290 where
291 A: serde::de::MapAccess<'de>,
292 {
293 let routes: HashMap<String, FunctionRoutes> =
294 Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
295
296 let mut inner = Router::new();
297 let mut raw = Vec::new();
298
299 let mut inverse = HashMap::new();
300
301 for (path, route) in &routes {
302 inner.insert(path, route.clone()).map_err(|e| {
303 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
304 })?;
305
306 match route {
307 FunctionRoutes::Single(function) => {
308 raw.push(Route {
309 path: path.clone(),
310 methods: None,
311 function: function.clone(),
312 });
313 }
314 FunctionRoutes::Multiple(routes) => {
315 for (method, function) in routes {
316 inverse
317 .entry((path.clone(), function.clone()))
318 .and_modify(|route: &mut Route| {
319 let mut methods = route.methods.clone().unwrap_or_default();
320 methods.push(method.clone());
321 route.methods = Some(methods);
322 })
323 .or_insert_with(|| Route {
324 path: path.clone(),
325 methods: Some(vec![method.clone()]),
326 function: function.clone(),
327 });
328 }
329 }
330 }
331 }
332
333 for (_, route) in inverse {
334 raw.push(route);
335 }
336
337 Ok(FunctionRouter { inner, raw })
338 }
339
340 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
341 where
342 A: serde::de::SeqAccess<'de>,
343 {
344 let routes: Vec<Route> =
345 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
346
347 let mut inner = Router::new();
348 let mut raw = Vec::new();
349
350 let mut routes_by_path = HashMap::new();
351
352 for route in &routes {
353 routes_by_path
354 .entry(route.path.clone())
355 .and_modify(|routes| merge_routes(routes, route))
356 .or_insert_with(|| decode_route(route));
357
358 raw.push(route.clone());
359 }
360
361 for (path, route) in &routes_by_path {
362 inner.insert(path, route.clone()).map_err(|e| {
363 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
364 })?;
365 }
366
367 Ok(FunctionRouter { inner, raw })
368 }
369}
370
371fn merge_routes(routes: &mut FunctionRoutes, route: &Route) {
372 let methods = route.methods.clone().unwrap_or_default();
373 match routes {
374 FunctionRoutes::Single(function) if !methods.is_empty() => {
375 let mut tmp = HashMap::new();
376 for method in methods {
377 tmp.insert(method.clone(), function.clone());
378 }
379 *routes = FunctionRoutes::Multiple(tmp);
380 }
381 FunctionRoutes::Multiple(_) if methods.is_empty() => {
382 *routes = FunctionRoutes::Single(route.function.clone());
383 }
384 FunctionRoutes::Multiple(routes) => {
385 for method in methods {
386 routes.insert(method.clone(), route.function.clone());
387 }
388 }
389 FunctionRoutes::Single(_) => {
390 *routes = FunctionRoutes::Single(route.function.clone());
391 }
392 }
393}
394
395fn decode_route(route: &Route) -> FunctionRoutes {
396 match &route.methods {
397 Some(methods) => {
398 let mut routes = HashMap::new();
399 for method in methods {
400 routes.insert(method.clone(), route.function.clone());
401 }
402 FunctionRoutes::Multiple(routes)
403 }
404 None => FunctionRoutes::Single(route.function.clone()),
405 }
406}
407
408impl<'de> Deserialize<'de> for FunctionRouter {
409 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
410 where
411 D: serde::Deserializer<'de>,
412 {
413 deserializer.deserialize_any(FunctionRouterVisitor)
414 }
415}
416
417impl Serialize for FunctionRouter {
418 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
419 where
420 S: serde::Serializer,
421 {
422 self.raw.serialize(serializer)
423 }
424}
425
426impl<'de> Deserialize<'de> for FunctionRoutes {
427 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
428 where
429 D: serde::Deserializer<'de>,
430 {
431 let value = Value::deserialize(deserializer)?;
432 match value {
433 Value::String(s) => Ok(FunctionRoutes::Single(s)),
434 Value::Array(arr) => {
435 let mut routes = HashMap::new();
436 for item in arr {
437 let obj = item.as_object().ok_or_else(|| {
438 Error::custom("Array items must be objects with method and function fields")
439 })?;
440
441 let method = obj
442 .get("method")
443 .and_then(|m| m.as_str())
444 .ok_or_else(|| Error::custom("Missing or invalid method field"))?;
445
446 let function = obj
447 .get("function")
448 .and_then(|f| f.as_str())
449 .ok_or_else(|| Error::custom("Missing or invalid function field"))?;
450
451 routes.insert(method.to_string(), function.to_string());
452 }
453 Ok(FunctionRoutes::Multiple(routes))
454 }
455 _ => Err(Error::custom(
456 "Function routes must be either a string or an array of objects with method and function fields",
457 )),
458 }
459 }
460}
461
462impl Serialize for FunctionRoutes {
463 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
464 where
465 S: serde::Serializer,
466 {
467 match self {
468 FunctionRoutes::Single(function) => function.serialize(serializer),
469 FunctionRoutes::Multiple(routes) => {
470 let mut seq = serializer.serialize_seq(Some(routes.len()))?;
471 for (method, function) in routes {
472 let mut map = serde_json::Map::new();
473 map.insert("method".to_string(), json!(method));
474 map.insert("function".to_string(), json!(function));
475 seq.serialize_element(&Value::Object(map))?;
476 }
477 seq.end()
478 }
479 }
480 }
481}
482
// Tests for router configuration parsing, route lookup, and `Watch` (de)serialization.
#[cfg(test)]
mod tests {

    use cargo_options::CommonOptions;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use super::*;

    // Map-form TOML: a path maps either to a list of { function, method } objects
    // (per-method routing) or to a bare function name (every method).
    #[test]
    fn test_router_deserialize() {
        let router: FunctionRouter = toml::from_str(
            r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
            "#,
        )
        .unwrap();

        assert_eq!(
            router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // Lookup behaviour: an empty router matches nothing, `Single` answers every method,
    // `Multiple` answers only the methods it lists, and path parameters come back by name.
    #[test]
    fn test_router_get() {
        let router = FunctionRouter::default();
        assert_eq!(router.at("/api/v1/users", "GET"), Err(MatchError::NotFound));

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Multiple(HashMap::from([
                    ("GET".to_string(), "get_user".to_string()),
                    ("POST".to_string(), "create_user".to_string()),
                ])),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("get_user".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("create_user".to_string(), HashMap::new()))
        );
        // A method the `Multiple` route does not list is reported as not found.
        assert_eq!(router.at("/api/v1/users", "PUT"), Err(MatchError::NotFound));

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users/{id}",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };

        let (function, params) = router.at("/api/v1/users/1", "GET").unwrap();
        assert_eq!(function, "user_handler");
        assert_eq!(params, HashMap::from([("id".to_string(), "1".to_string())]));
    }

    // Round trip: map-form TOML -> JSON sequence of `Route`s -> router again. The flat
    // `raw` list and the resulting lookups must both survive unchanged.
    #[test]
    fn test_router_serialize() {
        let config = r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
        "#;
        let router: FunctionRouter = toml::from_str(config).unwrap();

        let json = serde_json::to_value(&router).unwrap();

        let new_router: FunctionRouter = serde_json::from_value(json).unwrap();
        assert_eq!(new_router.raw, router.raw);

        assert_eq!(
            new_router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            new_router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // `Watch` serializes flattened env, TLS and cargo options at the top level, and the
    // output deserializes back into equivalent option values.
    #[test]
    fn test_watch_serialization() {
        let watch = Watch {
            invoke_address: "127.0.0.1".to_string(),
            invoke_port: 9000,
            env_options: EnvOptions {
                env_file: Some(PathBuf::from("/tmp/env")),
                env_var: Some(vec!["FOO=BAR".to_string()]),
            },
            tls_options: TlsOptions::new(
                Some(PathBuf::from("/tmp/cert.pem")),
                Some(PathBuf::from("/tmp/key.pem")),
                Some(PathBuf::from("/tmp/ca.pem")),
            ),
            cargo_opts: Run {
                common: CommonOptions {
                    quiet: false,
                    jobs: None,
                    keep_going: false,
                    profile: None,
                    features: vec!["feature1".to_string()],
                    all_features: false,
                    no_default_features: true,
                    target: vec!["x86_64-unknown-linux-gnu".to_string()],
                    target_dir: Some(PathBuf::from("/tmp/target")),
                    message_format: vec!["json".to_string()],
                    verbose: 1,
                    color: Some("auto".to_string()),
                    frozen: true,
                    locked: true,
                    offline: true,
                    config: vec!["config.toml".to_string()],
                    unstable_flags: vec!["flag1".to_string()],
                    timings: None,
                },
                manifest_path: None,
                release: false,
                ignore_rust_version: false,
                unit_graph: false,
                packages: vec![],
                bin: vec![],
                example: vec![],
                args: vec![],
            },
            ..Default::default()
        };

        let json = serde_json::to_value(&watch).unwrap();
        assert_eq!(json["invoke_address"], "127.0.0.1");
        assert_eq!(json["invoke_port"], 9000);
        assert_eq!(json["env_file"], "/tmp/env");
        assert_eq!(json["env_var"], json!(["FOO=BAR"]));
        assert_eq!(json["tls_cert"], "/tmp/cert.pem");
        assert_eq!(json["tls_key"], "/tmp/key.pem");
        assert_eq!(json["tls_ca"], "/tmp/ca.pem");
        assert_eq!(json["features"], json!(["feature1"]));
        assert_eq!(json["no_default_features"], true);
        assert_eq!(json["target"], json!(["x86_64-unknown-linux-gnu"]));
        assert_eq!(json["target_dir"], "/tmp/target");
        assert_eq!(json["message_format"], json!(["json"]));
        assert_eq!(json["verbose"], 1);
        assert_eq!(json["color"], "auto");
        assert_eq!(json["frozen"], true);
        assert_eq!(json["locked"], true);
        assert_eq!(json["offline"], true);
        assert_eq!(json["config"], json!(["config.toml"]));
        assert_eq!(json["unstable_flags"], json!(["flag1"]));
        // An unset optional value must not appear in the output (indexing yields Null).
        assert_eq!(json["timings"], Value::Null);

        let deserialized: Watch = serde_json::from_value(json).unwrap();

        assert_eq!(deserialized.invoke_address, watch.invoke_address);
        assert_eq!(deserialized.invoke_port, watch.invoke_port);
        assert_eq!(
            deserialized.env_options.env_file,
            watch.env_options.env_file
        );
        assert_eq!(deserialized.env_options.env_var, watch.env_options.env_var);
        assert_eq!(
            deserialized.tls_options.tls_cert,
            watch.tls_options.tls_cert
        );
        assert_eq!(deserialized.tls_options.tls_key, watch.tls_options.tls_key);
        assert_eq!(deserialized.tls_options.tls_ca, watch.tls_options.tls_ca);
        assert_eq!(
            deserialized.cargo_opts.common.features,
            watch.cargo_opts.common.features
        );
        assert_eq!(
            deserialized.cargo_opts.common.no_default_features,
            watch.cargo_opts.common.no_default_features
        );
        assert_eq!(
            deserialized.cargo_opts.common.target,
            watch.cargo_opts.common.target
        );
        assert_eq!(
            deserialized.cargo_opts.common.target_dir,
            watch.cargo_opts.common.target_dir
        );
        assert_eq!(
            deserialized.cargo_opts.common.message_format,
            watch.cargo_opts.common.message_format
        );
        assert_eq!(
            deserialized.cargo_opts.common.verbose,
            watch.cargo_opts.common.verbose
        );
        assert_eq!(
            deserialized.cargo_opts.common.color,
            watch.cargo_opts.common.color
        );
        assert_eq!(
            deserialized.cargo_opts.common.frozen,
            watch.cargo_opts.common.frozen
        );
        assert_eq!(
            deserialized.cargo_opts.common.locked,
            watch.cargo_opts.common.locked
        );
        assert_eq!(
            deserialized.cargo_opts.common.offline,
            watch.cargo_opts.common.offline
        );
        assert_eq!(
            deserialized.cargo_opts.common.config,
            watch.cargo_opts.common.config
        );
        assert_eq!(
            deserialized.cargo_opts.common.unstable_flags,
            watch.cargo_opts.common.unstable_flags
        );
        assert_eq!(
            deserialized.cargo_opts.common.timings,
            watch.cargo_opts.common.timings
        );
    }
}