trustless-provider-lambda 0.3.3

AWS Lambda key provider for trustless
/// Key-provider handler that forwards every protocol request to an AWS Lambda function.
#[derive(Debug)]
struct LambdaHandler {
    /// AWS Lambda client used to issue `Invoke` calls.
    client: aws_sdk_lambda::Client,
    /// Function name (or ARN) passed to every `Invoke` call.
    function_name: String,
}

impl LambdaHandler {
    async fn invoke_lambda(
        &self,
        request: trustless_protocol::message::Request,
    ) -> Result<trustless_protocol::message::Response, trustless_protocol::message::ErrorCode> {
        let payload = serde_json::to_vec(&request).map_err(|e| {
            trustless_protocol::message::ErrorCode::Internal(format!(
                "failed to serialize request: {e}"
            ))
        })?;

        let result = self
            .client
            .invoke()
            .function_name(&self.function_name)
            .payload(aws_sdk_lambda::primitives::Blob::new(payload))
            .send()
            .await
            .map_err(|e| {
                trustless_protocol::message::ErrorCode::Internal(format!(
                    "Lambda invocation failed: {e}"
                ))
            })?;

        if let Some(func_error) = result.function_error() {
            let error_message = result
                .payload()
                .and_then(|p| String::from_utf8(p.as_ref().to_vec()).ok())
                .unwrap_or_else(|| func_error.to_owned());
            return Err(trustless_protocol::message::ErrorCode::Internal(format!(
                "Lambda function error: {error_message}"
            )));
        }

        let response_payload = result.payload().ok_or_else(|| {
            trustless_protocol::message::ErrorCode::Internal(
                "Lambda returned no payload".to_owned(),
            )
        })?;

        let response: trustless_protocol::message::Response =
            serde_json::from_slice(response_payload.as_ref()).map_err(|e| {
                trustless_protocol::message::ErrorCode::Internal(format!(
                    "failed to deserialize Lambda response: {e}"
                ))
            })?;

        Ok(response)
    }
}

impl trustless_protocol::handler::Handler for LambdaHandler {
    async fn initialize(
        &self,
    ) -> Result<trustless_protocol::message::InitializeResult, trustless_protocol::message::ErrorCode>
    {
        let request = trustless_protocol::message::Request::Initialize {
            id: 0,
            params: trustless_protocol::message::InitializeParams {},
        };
        let response = self.invoke_lambda(request).await?;
        match response {
            trustless_protocol::message::Response::Success(
                trustless_protocol::message::SuccessResponse::Initialize { result, .. },
            ) => Ok(result),
            trustless_protocol::message::Response::Success(_) => {
                Err(trustless_protocol::message::ErrorCode::Internal(
                    "unexpected response method".to_owned(),
                ))
            }
            trustless_protocol::message::Response::Error(
                trustless_protocol::message::ErrorResponse { error, .. },
            ) => Err(error.into()),
        }
    }

    async fn sign(
        &self,
        params: trustless_protocol::message::SignParams,
    ) -> Result<trustless_protocol::message::SignResult, trustless_protocol::message::ErrorCode>
    {
        let request = trustless_protocol::message::Request::Sign { id: 0, params };
        let response = self.invoke_lambda(request).await?;
        match response {
            trustless_protocol::message::Response::Success(
                trustless_protocol::message::SuccessResponse::Sign { result, .. },
            ) => Ok(result),
            trustless_protocol::message::Response::Success(_) => {
                Err(trustless_protocol::message::ErrorCode::Internal(
                    "unexpected response method".to_owned(),
                ))
            }
            trustless_protocol::message::Response::Error(
                trustless_protocol::message::ErrorResponse { error, .. },
            ) => Err(error.into()),
        }
    }
}

/// AWS Lambda key provider for trustless.
#[derive(clap::Parser)]
#[clap(version)]
struct Args {
    /// AWS Lambda function to invoke for every provider request (a function name or ARN,
    /// as accepted by the Lambda `Invoke` API).
    #[clap(long)]
    function_name: String,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    use clap::Parser as _;

    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("warn")),
        )
        .with_writer(std::io::stderr)
        .init();

    let args = Args::parse();

    let aws_config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
    let lambda_client = aws_sdk_lambda::Client::new(&aws_config);

    let handler = LambdaHandler {
        client: lambda_client,
        function_name: args.function_name,
    };

    tracing::info!(
        function_name = %handler.function_name,
        "starting Lambda provider"
    );

    trustless_protocol::handler::run(handler).await?;

    Ok(())
}