web-push 0.7.2

Web push notification client with support for http-ece encryption and VAPID authentication.
Documentation
use base64;
use crate::error::WebPushError;
use http::header::{AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE};
use hyper::{Body, Request, StatusCode};
use crate::message::WebPushMessage;
use serde_json;

/// Error codes returned by the GCM/FCM legacy HTTP API.
///
/// Each unit variant (de)serializes as its exact name, so a JSON value such as
/// `"error": "NotRegistered"` decodes to `GcmError::NotRegistered`. Any string that matches
/// none of these variants makes deserialization of the enclosing response fail.
/// The `From<&GcmError> for WebPushError` impl in this module defines how each code maps onto
/// the crate's error type.
#[derive(Deserialize, Serialize, Debug)]
pub enum GcmError {
    MissingRegistration,
    InvalidRegistration,
    NotRegistered,
    InvalidPackageName,
    MismatchSenderId,
    InvalidParameters,
    MessageTooBig,
    InvalidDataKey,
    InvalidTtl,
    Unavailable,
    InternalServerError,
    DeviceMessageRateExceeded,
    TopicsMessageRateExceeded,
    InvalidApnsCredential,
}

/// Translates a GCM/FCM error code into the crate-wide [`WebPushError`].
///
/// Registration problems become endpoint errors, and transient server-side codes become
/// `ServerError(None)`. Every code without a dedicated `WebPushError` variant falls back to
/// `WebPushError::Other`, carrying the code's `Debug` name (for example `"InvalidDataKey"`).
impl From<&GcmError> for WebPushError {
    fn from(e: &GcmError) -> Self {
        match *e {
            GcmError::MissingRegistration => WebPushError::EndpointNotFound,
            GcmError::InvalidRegistration | GcmError::NotRegistered => {
                WebPushError::EndpointNotValid
            }
            GcmError::InvalidPackageName => WebPushError::InvalidPackageName,
            GcmError::MessageTooBig => WebPushError::PayloadTooLarge,
            GcmError::InvalidTtl => WebPushError::InvalidTtl,
            GcmError::Unavailable | GcmError::InternalServerError => {
                WebPushError::ServerError(None)
            }
            // No dedicated variant: keep the raw code name so callers can still inspect it.
            ref unmapped => WebPushError::Other(format!("{:?}", unmapped)),
        }
    }
}

/// Decoded JSON body of a GCM/FCM legacy API response.
///
/// Every field is optional, so keys absent from the response decode as `None`.
/// `parse_response` consults only `failure` and `results` (for `200 OK`) and `error`
/// (for `400 Bad Request`). The remaining fields are decoded but not used by `parse_response`.
#[derive(Deserialize, Serialize, Debug)]
pub struct GcmResponse {
    pub message_id: Option<u64>,
    /// Top-level error code; `parse_response` reads it only for `400 Bad Request` responses.
    pub error: Option<GcmError>,
    pub multicast_id: Option<i64>,
    pub success: Option<u64>,
    /// Count of failed deliveries; `parse_response` treats exactly `Some(0)` as success.
    pub failure: Option<u64>,
    pub canonical_ids: Option<u64>,
    /// Per-registration outcomes; `parse_response` inspects only the first entry.
    pub results: Option<Vec<MessageResult>>,
}

/// One entry of the `results` array in a [`GcmResponse`].
///
/// All fields are optional and decode as `None` when absent. `parse_response` reads `error`
/// from the first entry to explain a failed delivery.
#[derive(Deserialize, Serialize, Debug)]
pub struct MessageResult {
    pub message_id: Option<String>,
    pub registration_id: Option<String>,
    /// Error code for this registration, if the delivery to it failed.
    pub error: Option<GcmError>,
}

/// JSON request body that `build_request` sends to the GCM/FCM legacy send endpoint.
#[derive(Deserialize, Serialize)]
struct GcmData {
    /// Target registration tokens. `build_request` supplies exactly one token, taken from the
    /// last `/`-separated segment of the subscription endpoint's path.
    registration_ids: Vec<String>,
    /// Base64 encoding of the payload's `content` bytes (presumably the http-ece ciphertext,
    /// verify in `message.rs`). It is `None` when the message has no payload, which serializes
    /// as JSON `null` because no `skip_serializing_if` is set.
    raw_data: Option<String>,
}

/// Builds the HTTP request that delivers `message` through the GCM/FCM legacy HTTP API.
///
/// The request is a `POST` to one of two URLs:
/// - `https://fcm.googleapis.com/fcm/send`, when the subscription endpoint's host is
///   `fcm.googleapis.com`;
/// - `https://android.googleapis.com/gcm/send`, otherwise.
///
/// Request contents:
/// - An `Authorization: key=<gcm_key>` header is added when a GCM key is set.
/// - The registration token is the last `/`-separated segment of the endpoint path.
/// - When the message carries a payload, its crypto headers are copied onto the request and
///   its content is base64-encoded into the JSON body's `raw_data` field.
///
/// # Panics
///
/// Panics if the HTTP builder rejects a header value, for example a `gcm_key` containing
/// characters that are not allowed in an HTTP header value.
pub fn build_request(message: WebPushMessage) -> Request<Body> {
    let uri = match message.endpoint.host() {
        Some("fcm.googleapis.com") => "https://fcm.googleapis.com/fcm/send",
        _ => "https://android.googleapis.com/gcm/send",
    };

    let mut builder = Request::builder().method("POST").uri(uri);

    if let Some(ref gcm_key) = message.gcm_key {
        builder = builder.header(AUTHORIZATION, format!("key={}", gcm_key));
    }

    // The registration token is the final path segment of the endpoint. `rsplit` always yields
    // at least one (possibly empty) segment, so exactly one token is sent.
    let token = message.endpoint.path().rsplit('/').next().unwrap_or_default();
    let registration_ids = vec![token.to_string()];

    let raw_data = match message.payload {
        Some(payload) => {
            // Forward the payload's crypto headers unchanged on the outgoing request.
            for (name, value) in payload.crypto_headers {
                let value: &str = value.as_ref();
                builder = builder.header(name, value);
            }

            Some(base64::encode(&payload.content))
        }
        None => None,
    };

    let gcm_data = GcmData {
        registration_ids,
        raw_data,
    };

    // A struct of plain strings always serializes; a failure here would be a bug.
    let json_payload =
        serde_json::to_string(&gcm_data).expect("GcmData must always serialize to JSON");
    let content_length = json_payload.len().to_string();

    builder
        .header(CONTENT_TYPE, "application/json")
        .header(CONTENT_LENGTH, content_length)
        .body(json_payload.into())
        .expect("GCM request headers must be valid HTTP header values")
}

pub fn parse_response(response_status: StatusCode, body: Vec<u8>) -> Result<(), WebPushError> {
    match response_status {
        StatusCode::OK => {
            let body_str = String::from_utf8(body)?;
            let gcm_response: GcmResponse = serde_json::from_str(&body_str)?;

            if let Some(0) = gcm_response.failure {
                Ok(())
            } else {
                match gcm_response.results {
                    Some(results) => match results.first() {
                        Some(result) => match result.error {
                            Some(ref error) => Err(WebPushError::from(error)),
                            _ => Err(WebPushError::Other(String::from("UnknownError"))),
                        },
                        _ => Err(WebPushError::Other(String::from("UnknownError"))),
                    },
                    _ => Err(WebPushError::Other(String::from("UnknownError"))),
                }
            }
        }
        StatusCode::UNAUTHORIZED => Err(WebPushError::Unauthorized),
        StatusCode::BAD_REQUEST => {
            let body_str = String::from_utf8(body)?;
            let gcm_response: GcmResponse = serde_json::from_str(&body_str)?;

            match gcm_response.error {
                Some(e) => Err(WebPushError::from(&e)),
                _ => Err(WebPushError::BadRequest(None)),
            }
        }
        status if status.is_server_error() => Err(WebPushError::ServerError(None)),
        e => Err(WebPushError::Other(format!("{:?}", e))),
    }
}

#[cfg(test)]
mod tests {
    use crate::error::WebPushError;
    use crate::http_ece::ContentEncoding;
    use crate::message::{SubscriptionInfo, WebPushMessageBuilder};
    use crate::services::firebase::*;
    use hyper::StatusCode;
    use hyper::Uri;

    // Second and third arguments of `SubscriptionInfo::new`, presumably the client's p256dh
    // public key and auth secret (confirm the argument order in `message.rs`).
    const SUBSCRIPTION_KEY: &str =
        "BLMbF9ffKBiWQLCKvTHb6LO8Nb6dcUh6TItC455vu2kElga6PQvUmaFyCdykxY2nOSSL3yKgfbmFLRTUaGv4yV8";
    const SUBSCRIPTION_AUTH: &str = "xS03Fi5ErfTNH_l9WHE9Ig";

    /// Builds a request for `endpoint` with GCM key `test_key` and TTL 420. An AES-GCM
    /// payload is added only when `payload` is `Some`.
    fn request_for(endpoint: &str, payload: Option<&[u8]>) -> hyper::Request<hyper::Body> {
        let info = SubscriptionInfo::new(endpoint, SUBSCRIPTION_KEY, SUBSCRIPTION_AUTH);
        let mut builder = WebPushMessageBuilder::new(&info).unwrap();

        builder.set_gcm_key("test_key");
        builder.set_ttl(420);
        if let Some(content) = payload {
            builder.set_payload(ContentEncoding::AesGcm, content);
        }

        build_request(builder.build().unwrap())
    }

    /// Returns the request's `Authorization` header as text.
    fn authorization(request: &hyper::Request<hyper::Body>) -> &str {
        request
            .headers()
            .get("Authorization")
            .unwrap()
            .to_str()
            .unwrap()
    }

    /// A `200 OK` GCM response body reporting one failed delivery with the given error code.
    fn failure_response(error: &str) -> Vec<u8> {
        format!(
            r#"{{"message_id": 12, "multicast_id": 33, "success": 0, "failure": 1, "results": [{{"error":"{}"}}]}}"#,
            error
        )
        .into_bytes()
    }

    #[test]
    fn builds_a_correct_request_with_empty_payload() {
        let request = request_for("https://android.googleapis.com/gcm/send/device_token_2", None);
        let expected_uri: Uri = "https://android.googleapis.com/gcm/send".parse().unwrap();
        // Body: {"registration_ids":["device_token_2"],"raw_data":null}
        let length = request.headers().get("Content-Length").unwrap();

        assert_eq!("key=test_key", authorization(&request));
        assert_eq!(expected_uri.host(), request.uri().host());
        assert_eq!(expected_uri.path(), request.uri().path());
        assert_eq!("55", length);
    }

    #[test]
    fn builds_a_correct_request_with_a_payload() {
        let request = request_for(
            "https://fcm.googleapis.com/gcm/send/device_token_2",
            Some("test".as_bytes()),
        );
        let expected_uri: Uri = "https://fcm.googleapis.com/fcm/send".parse().unwrap();
        let length = request.headers().get("Content-Length").unwrap();

        assert_eq!("key=test_key", authorization(&request));
        assert_eq!(expected_uri.host(), request.uri().host());
        assert_eq!(expected_uri.path(), request.uri().path());
        assert_eq!("4149", length);
    }

    #[test]
    fn parses_a_successful_response_correctly() {
        let response = r#"
        {
            "message_id": 12,
            "multicast_id": 33,
            "success": 1,
            "failure": 0
        }
        "#;
        assert_eq!(
            Ok(()),
            parse_response(StatusCode::OK, response.as_bytes().to_vec())
        )
    }

    #[test]
    fn maps_per_result_error_codes() {
        let cases = vec![
            ("MissingRegistration", WebPushError::EndpointNotFound),
            ("InvalidRegistration", WebPushError::EndpointNotValid),
            ("NotRegistered", WebPushError::EndpointNotValid),
            ("InvalidPackageName", WebPushError::InvalidPackageName),
            ("MessageTooBig", WebPushError::PayloadTooLarge),
            (
                "InvalidDataKey",
                WebPushError::Other(String::from("InvalidDataKey")),
            ),
            ("InvalidTtl", WebPushError::InvalidTtl),
            ("Unavailable", WebPushError::ServerError(None)),
            ("InternalServerError", WebPushError::ServerError(None)),
        ];

        for (code, expected) in cases {
            assert_eq!(
                Err(expected),
                parse_response(StatusCode::OK, failure_response(code)),
                "GCM error code {}",
                code
            );
        }
    }

    #[test]
    fn reports_unknown_error_when_results_are_missing() {
        let response = r#"{"multicast_id": 33, "success": 0, "failure": 1}"#;
        assert_eq!(
            Err(WebPushError::Other(String::from("UnknownError"))),
            parse_response(StatusCode::OK, response.as_bytes().to_vec())
        )
    }

    #[test]
    fn reports_unknown_error_when_result_has_no_error() {
        let response = r#"{"multicast_id": 33, "success": 0, "failure": 1, "results": [{}]}"#;
        assert_eq!(
            Err(WebPushError::Other(String::from("UnknownError"))),
            parse_response(StatusCode::OK, response.as_bytes().to_vec())
        )
    }

    #[test]
    fn parses_an_unauthorized_response_correctly() {
        assert_eq!(
            Err(WebPushError::Unauthorized),
            parse_response(StatusCode::UNAUTHORIZED, Vec::new())
        )
    }

    #[test]
    fn parses_a_bad_request_with_an_error_code() {
        let response = r#"{"error":"InvalidTtl"}"#;
        assert_eq!(
            Err(WebPushError::InvalidTtl),
            parse_response(StatusCode::BAD_REQUEST, response.as_bytes().to_vec())
        )
    }

    #[test]
    fn parses_a_bad_request_without_an_error_code() {
        assert_eq!(
            Err(WebPushError::BadRequest(None)),
            parse_response(StatusCode::BAD_REQUEST, b"{}".to_vec())
        )
    }

    #[test]
    fn parses_a_server_error_status_correctly() {
        assert_eq!(
            Err(WebPushError::ServerError(None)),
            parse_response(StatusCode::SERVICE_UNAVAILABLE, Vec::new())
        )
    }

    #[test]
    fn parses_an_unexpected_status_correctly() {
        assert_eq!(
            Err(WebPushError::Other(format!("{:?}", StatusCode::NOT_FOUND))),
            parse_response(StatusCode::NOT_FOUND, Vec::new())
        )
    }
}