// cargo_lambda_invoke/lib.rs

1use base64::{Engine as _, engine::general_purpose as b64};
2use cargo_lambda_metadata::DEFAULT_PACKAGE_FUNCTION;
3use cargo_lambda_remote::{
4    RemoteConfig,
5    aws_sdk_lambda::{Client as LambdaClient, primitives::Blob},
6    tls::TlsOptions,
7};
8use clap::{Args, ValueHint};
9use miette::{IntoDiagnostic, Result, WrapErr};
10use reqwest::{Client, StatusCode};
11use serde::Serialize;
12use serde_json::{from_str, to_string_pretty, value::Value};
13use std::{
14    convert::TryFrom,
15    fs::{File, create_dir_all, read_to_string},
16    io::copy,
17    net::IpAddr,
18    path::PathBuf,
19    str::{FromStr, from_utf8},
20};
21use strum_macros::{Display, EnumString};
22use tracing::debug;
23
24mod error;
25use error::*;
26
/// Base URL that example payloads are fetched from by `download_example`
/// when no other authority is supplied.
const EXAMPLES_URL: &str = "https://event-examples.cargo-lambda.info";

/// HTTP header carrying the client context sent to the local runtime emulator.
const LAMBDA_RUNTIME_CLIENT_CONTEXT: &str = "lambda-runtime-client-context";
/// HTTP header carrying the JSON-serialized Cognito identity sent to the local runtime emulator.
const LAMBDA_RUNTIME_COGNITO_IDENTITY: &str = "lambda-runtime-cognito-identity";
31
32#[derive(Args, Clone, Debug)]
33#[command(
34    name = "invoke",
35    after_help = "Full command documentation: https://www.cargo-lambda.info/commands/invoke.html"
36)]
37pub struct Invoke {
38    #[cfg_attr(
39        target_os = "windows",
40        arg(short = 'a', long, default_value = "127.0.0.1")
41    )]
42    #[cfg_attr(
43        not(target_os = "windows"),
44        arg(short = 'A', long, default_value = "::1")
45    )]
46    /// Local address host (IPv4 or IPv6) to send invoke requests
47    invoke_address: String,
48
49    /// Local port to send invoke requests
50    #[arg(short = 'P', long, default_value = "9000")]
51    invoke_port: u16,
52
53    /// File to read the invoke payload from
54    #[arg(short = 'F', long, value_hint = ValueHint::FilePath)]
55    data_file: Option<PathBuf>,
56
57    /// Invoke payload as a string
58    #[arg(short = 'A', long)]
59    data_ascii: Option<String>,
60
61    /// Example payload from AWS Lambda Events
62    #[arg(short = 'E', long)]
63    data_example: Option<String>,
64
65    /// Invoke the function already deployed on AWS Lambda
66    #[arg(short = 'R', long)]
67    remote: bool,
68
69    #[command(flatten)]
70    remote_config: RemoteConfig,
71
72    /// JSON string representing the client context for the function invocation
73    #[arg(long)]
74    client_context_ascii: Option<String>,
75
76    /// Path to a file with the JSON representation of the client context for the function invocation
77    #[arg(long)]
78    client_context_file: Option<PathBuf>,
79
80    /// Format to render the output (text, or json)
81    #[arg(short, long, default_value_t = OutputFormat::Text)]
82    output_format: OutputFormat,
83
84    #[command(flatten)]
85    cognito: Option<CognitoIdentity>,
86
87    /// Ignore data stored in the local cache
88    #[arg(long, default_value_t = false)]
89    skip_cache: bool,
90
91    /// Name of the function to invoke
92    #[arg(default_value = DEFAULT_PACKAGE_FUNCTION)]
93    function_name: String,
94
95    #[command(flatten)]
96    tls_options: TlsOptions,
97}
98
/// How the invocation response is printed by `Invoke::run`.
///
/// Parsed case-insensitively from `--output-format` through strum's
/// `EnumString`; `Display` is used to render the `default_value_t` shown by clap.
#[derive(Clone, Debug, Display, EnumString)]
#[strum(ascii_case_insensitive)]
enum OutputFormat {
    /// Print the response body unchanged.
    Text,
    /// Parse the response body as JSON and pretty-print it.
    Json,
}
105
// Cognito identity forwarded to the local runtime emulator as a JSON header.
// Serialized field names are `cognitoIdentityId` and `cognitoIdentityPoolId`.
// Only `//` comments are used here because clap turns `///` lines into
// `--help` text.
//
// NOTE(review): clap 4's derive uses the verbatim field name as the argument
// id (e.g. `identity_pool_id`), while the `requires` attributes below refer to
// the kebab-case `identity-pool-id` / `identity-id`. Confirm against the clap
// version in use that these ids resolve; clap's debug assertions reject
// `requires` targets that do not exist.
#[derive(Args, Clone, Debug, Serialize)]
pub struct CognitoIdentity {
    /// The unique identity id for the Cognito credentials invoking the function.
    #[arg(long, requires = "identity-pool-id")]
    #[serde(rename = "cognitoIdentityId")]
    pub identity_id: Option<String>,
    /// The identity pool id the caller is "registered" with.
    #[arg(long, requires = "identity-id")]
    #[serde(rename = "cognitoIdentityPoolId")]
    pub identity_pool_id: Option<String>,
}
117
118impl CognitoIdentity {
119    fn is_valid(&self) -> bool {
120        self.identity_id.is_some() && self.identity_pool_id.is_some()
121    }
122}
123
124impl Invoke {
125    #[tracing::instrument(skip(self), target = "cargo_lambda")]
126    pub async fn run(&self) -> Result<()> {
127        tracing::trace!(options = ?self, "invoking function");
128
129        let data = if let Some(file) = &self.data_file {
130            read_to_string(file)
131                .into_diagnostic()
132                .wrap_err("error reading data file")?
133        } else if let Some(data) = &self.data_ascii {
134            data.clone()
135        } else if let Some(example) = &self.data_example {
136            let name = example_name(example);
137
138            let cache = dirs::cache_dir()
139                .map(|p| p.join("cargo-lambda").join("invoke-fixtures").join(&name));
140
141            match cache {
142                Some(cache) if !self.skip_cache && cache.exists() => {
143                    tracing::debug!(?cache, "using example from cache");
144                    read_to_string(cache)
145                        .into_diagnostic()
146                        .wrap_err("error reading data file")?
147                }
148                _ if self.skip_cache => download_example(&name, None, None).await?,
149                _ => download_example(&name, cache, None).await?,
150            }
151        } else {
152            return Err(InvokeError::MissingPayload.into());
153        };
154
155        let text = if self.remote {
156            self.invoke_remote(&data).await?
157        } else {
158            self.invoke_local(&data).await?
159        };
160
161        let text = match &self.output_format {
162            OutputFormat::Text => text,
163            OutputFormat::Json => {
164                let obj: Value = from_str(&text)
165                    .into_diagnostic()
166                    .wrap_err("failed to serialize response into json")?;
167
168                to_string_pretty(&obj)
169                    .into_diagnostic()
170                    .wrap_err("failed to format json output")?
171            }
172        };
173
174        println!("{text}");
175
176        Ok(())
177    }
178
179    async fn invoke_remote(&self, data: &str) -> Result<String> {
180        if self.function_name == DEFAULT_PACKAGE_FUNCTION {
181            return Err(InvokeError::InvalidFunctionName.into());
182        }
183
184        let client_context = self.client_context(true)?;
185
186        let sdk_config = self.remote_config.sdk_config(None).await;
187        let client = LambdaClient::new(&sdk_config);
188
189        let resp = client
190            .invoke()
191            .function_name(&self.function_name)
192            .set_qualifier(self.remote_config.alias.clone())
193            .payload(Blob::new(data.as_bytes()))
194            .set_client_context(client_context)
195            .send()
196            .await
197            .into_diagnostic()
198            .wrap_err("failed to invoke remote function")?;
199
200        if let Some(payload) = resp.payload {
201            let blob = payload.into_inner();
202            let data = from_utf8(&blob)
203                .into_diagnostic()
204                .wrap_err("failed to read response payload")?;
205
206            if resp.function_error.is_some() {
207                let err = RemoteInvokeError::try_from(data)?;
208                Err(err.into())
209            } else {
210                Ok(data.into())
211            }
212        } else {
213            Ok("OK".into())
214        }
215    }
216
217    async fn invoke_local(&self, data: &str) -> Result<String> {
218        let host = parse_invoke_ip_address(&self.invoke_address)?;
219
220        let (protocol, client) = if self.tls_options.is_secure() {
221            let tls = self.tls_options.client_config()?;
222            let client = Client::builder()
223                .use_preconfigured_tls(tls)
224                .build()
225                .into_diagnostic()?;
226
227            ("https", client)
228        } else {
229            ("http", Client::new())
230        };
231
232        let url = format!(
233            "{}://{}:{}/2015-03-31/functions/{}/invocations",
234            protocol, &host, self.invoke_port, &self.function_name
235        );
236
237        let mut req = client.post(url).body(data.to_string());
238        if let Some(identity) = &self.cognito {
239            if identity.is_valid() {
240                let ser = serde_json::to_string(&identity)
241                    .into_diagnostic()
242                    .wrap_err("failed to serialize Cognito's identity information")?;
243                req = req.header(LAMBDA_RUNTIME_COGNITO_IDENTITY, ser);
244            }
245        }
246        if let Some(client_context) = self.client_context(false)? {
247            req = req.header(LAMBDA_RUNTIME_CLIENT_CONTEXT, client_context);
248        }
249
250        let resp = req
251            .send()
252            .await
253            .into_diagnostic()
254            .wrap_err("error sending request to the runtime emulator")?;
255        let success = resp.status() == StatusCode::OK;
256
257        let payload = resp
258            .text()
259            .await
260            .into_diagnostic()
261            .wrap_err("error reading response body")?;
262
263        if success {
264            Ok(payload)
265        } else {
266            debug!(error = ?payload, "error received from server");
267            let err = RemoteInvokeError::try_from(payload.as_str())?;
268            Err(err.into())
269        }
270    }
271
272    fn client_context(&self, encode: bool) -> Result<Option<String>> {
273        let mut data = if let Some(file) = &self.client_context_file {
274            read_to_string(file)
275                .into_diagnostic()
276                .wrap_err("error reading client context file")?
277        } else if let Some(data) = &self.client_context_ascii {
278            data.clone()
279        } else {
280            return Ok(None);
281        };
282
283        if encode {
284            data = b64::STANDARD.encode(data)
285        }
286
287        Ok(Some(data))
288    }
289}
290
/// Normalizes a user-supplied example name into the file name served by the
/// examples site.
///
/// The result always starts with `example-` and ends with `.json`. Each part
/// is added only when it is missing, so already-normalized names pass through
/// unchanged.
fn example_name(example: &str) -> String {
    let prefix = if example.starts_with("example-") {
        ""
    } else {
        "example-"
    };
    let suffix = if example.ends_with(".json") { "" } else { ".json" };
    format!("{prefix}{example}{suffix}")
}
302
303async fn download_example(
304    name: &str,
305    cache: Option<PathBuf>,
306    authority: Option<&str>,
307) -> Result<String> {
308    let authority = authority.unwrap_or(EXAMPLES_URL);
309    let target = format!("{authority}/{name}");
310
311    tracing::debug!(?target, "downloading remote example");
312    let response = reqwest::get(&target)
313        .await
314        .into_diagnostic()
315        .wrap_err("error dowloading example data")?;
316
317    if response.status() != StatusCode::OK {
318        Err(InvokeError::ExampleDownloadFailed(target, response).into())
319    } else {
320        let content = response
321            .text()
322            .await
323            .into_diagnostic()
324            .wrap_err("error reading example data")?;
325
326        if let Some(cache) = cache {
327            tracing::debug!(?cache, "storing example in cache");
328            create_dir_all(cache.parent().unwrap()).into_diagnostic()?;
329            let mut dest = File::create(cache).into_diagnostic()?;
330            copy(&mut content.as_bytes(), &mut dest).into_diagnostic()?;
331        }
332        Ok(content)
333    }
334}
335
336fn parse_invoke_ip_address(address: &str) -> Result<String> {
337    let invoke_address = IpAddr::from_str(address).map_err(|e| miette::miette!(e))?;
338
339    let invoke_address = match invoke_address {
340        IpAddr::V4(address) => address.to_string(),
341        IpAddr::V6(address) => format!("[{address}]"),
342    };
343
344    Ok(invoke_address)
345}
346
#[cfg(test)]
mod test {
    use httpmock::MockServer;

    use super::*;

    /// Downloading without a cache path returns the served body.
    #[tokio::test]
    async fn test_download_example() {
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.path("/example-apigw-request.json");
            then.status(200)
                .header("Content-Type", "application/json")
                .body_from_file("../../tests/fixtures/events/example-apigw-request.json");
        });

        let data = download_example(
            "example-apigw-request.json",
            None,
            Some(&format!("http://{}", server.address())),
        )
        .await
        .expect("failed to download json");

        mock.assert();
        assert!(data.contains("\"path\": \"/hello/world\""));
    }

    /// Downloading with a cache path also writes an identical copy to disk.
    #[tokio::test]
    async fn test_download_example_with_cache() {
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.path("/example-apigw-request.json");
            then.status(200)
                .header("Content-Type", "application/json")
                .body_from_file("../../tests/fixtures/events/example-apigw-request.json");
        });

        // Bind the `TempDir` guard so it lives for the whole test. Chaining
        // `.path()` on the temporary would drop the guard at the end of the
        // statement: the directory would be deleted before use, and whatever
        // the test writes there afterwards would be leaked.
        let temp_dir = tempfile::TempDir::new().unwrap();
        let cache = temp_dir
            .path()
            .join("cargo-lambda")
            .join("example-apigw-request.json");

        let data = download_example(
            "example-apigw-request.json",
            Some(cache.clone()),
            Some(&format!("http://{}", server.address())),
        )
        .await
        .unwrap();

        mock.assert();
        assert!(data.contains("\"path\": \"/hello/world\""));
        assert!(cache.exists());

        let content = read_to_string(&cache).unwrap();
        assert_eq!(content, data);
    }

    /// Names are normalized to `example-<name>.json` regardless of which parts
    /// the caller already supplied.
    #[test]
    fn test_example_name() {
        assert_eq!(example_name("apigw-request"), "example-apigw-request.json");
        assert_eq!(
            example_name("apigw-request.json"),
            "example-apigw-request.json"
        );
        assert_eq!(
            example_name("example-apigw-request"),
            "example-apigw-request.json"
        );
        assert_eq!(
            example_name("example-apigw-request.json"),
            "example-apigw-request.json"
        );
    }
}