ockam_api 0.93.0

Ockam's request-response API
use std::io::Write;

use ockam_core::async_trait;
use ockam_node::Context;
use ockam_transport_tcp::{Direction, PortalInterceptor, PortalInterceptorFactory};
use std::sync::Arc;
use tokio::sync::Mutex;

use super::token_lease_refresher::TokenLeaseRefresher;
use crate::http::state::{ClientRequestWriter, RequestState};
use tracing::{debug, error};

/// Mutable per-interceptor parsing state, guarded by the interceptor's mutex.
struct HttpAuthInterceptorState {
    /// Progress of the HTTP request parser for inlet → outlet traffic.
    state: RequestState,
}

/// Portal interceptor that rewrites HTTP request heads so they carry an
/// `Authorization: Token …` header obtained from a [`TokenLeaseRefresher`].
struct HttpAuthInterceptor {
    // A plain `Mutex` is sufficient: the interceptor itself is already shared via the
    // `Arc<dyn PortalInterceptor>` returned by the factory, and this state is never
    // cloned. The lock ensures only one call mutates the parser state at a time.
    state: Mutex<HttpAuthInterceptorState>,
    token_refresher: TokenLeaseRefresher,
}

impl HttpAuthInterceptor {
    /// Creates an interceptor whose parser starts in `RequestState::ParsingHeader(None)`,
    /// using `token_refresher` to obtain the token injected into each request.
    fn new(token_refresher: TokenLeaseRefresher) -> Self {
        Self {
            state: Mutex::new(HttpAuthInterceptorState {
                state: RequestState::ParsingHeader(None),
            }),
            token_refresher,
        }
    }
}

/// Creates one HTTP authorization interceptor per portal connection.
///
/// Each interceptor produced by [`PortalInterceptorFactory::create`] receives a
/// clone of this factory's [`TokenLeaseRefresher`] and starts with its own,
/// fresh request-parsing state.
pub struct HttpAuthInterceptorFactory {
    token_refresher: TokenLeaseRefresher,
}

impl HttpAuthInterceptorFactory {
    /// Builds a factory whose interceptors take their tokens from `token_refresher`.
    pub fn new(token_refresher: TokenLeaseRefresher) -> Self {
        HttpAuthInterceptorFactory { token_refresher }
    }
}

impl PortalInterceptorFactory for HttpAuthInterceptorFactory {
    fn create(&self) -> Arc<dyn PortalInterceptor> {
        // A new interceptor per call, so parser state is never shared between connections.
        let interceptor = HttpAuthInterceptor::new(self.token_refresher.clone());
        Arc::new(interceptor)
    }
}

/// Serializes the head of `req` into `buffer`, emitting an
/// `Authorization: Token <token>` header directly after the request line.
///
/// Any `Authorization` header already present in `req` (matched
/// case-insensitively) is dropped, so the injected token is the only
/// credential forwarded. Every other header is copied in its original order
/// with its value bytes untouched. The head is closed with an empty `\r\n` line.
///
/// # Panics
///
/// Panics if `req` has no method, path, or version, i.e. if its request line
/// has not been parsed. Callers must only pass a request whose head was
/// parsed completely.
fn attach_auth_token_and_serialize_into(
    req: &httparse::Request,
    token: &str,
    buffer: &mut Vec<u8>,
) {
    // `io::Write` for `Vec<u8>` only appends to memory and never reports an error.
    const VEC_WRITE: &str = "writing into a Vec<u8> cannot fail";

    debug!("Serializing http req header");

    // httparse leaves these as `None` until the request line has been parsed,
    // so a missing value means the caller broke the "complete head" contract.
    let method = req
        .method
        .expect("HTTP request method must be parsed before serialization");
    let path = req
        .path
        .expect("HTTP request path must be parsed before serialization");
    let version = req
        .version
        .expect("HTTP request version must be parsed before serialization");

    write!(buffer, "{} {} HTTP/1.{}\r\n", method, path, version).expect(VEC_WRITE);
    write!(buffer, "Authorization: Token {}\r\n", token).expect(VEC_WRITE);

    for header in req.headers.iter() {
        // Drop client-supplied credentials; only the injected token is forwarded.
        if header.name.eq_ignore_ascii_case("Authorization") {
            continue;
        }
        write!(buffer, "{}: ", header.name).expect(VEC_WRITE);
        buffer.extend_from_slice(header.value);
        buffer.extend_from_slice(b"\r\n");
    }

    // Blank line terminating the header section.
    buffer.extend_from_slice(b"\r\n");
}

#[async_trait]
impl PortalInterceptor for HttpAuthInterceptor {
    /// Forwards outlet → inlet bytes unchanged. Inlet → outlet bytes are run
    /// through the HTTP request parser, which injects the authorization header.
    /// Always yields `Some(bytes)` on success.
    async fn intercept(
        &self,
        _context: &mut Context,
        direction: Direction,
        buffer: &[u8],
    ) -> ockam_core::Result<Option<Vec<u8>>> {
        let forwarded = match direction {
            // Outlet → inlet traffic is passed through as an unmodified copy.
            Direction::FromOutletToInlet => buffer.to_vec(),
            Direction::FromInletToOutlet => {
                // Keep the lock for the whole chunk so concurrent calls cannot
                // interleave updates to the parser state.
                let mut locked = self.state.lock().await;
                locked.state.process_http_buffer(buffer, self)?
            }
        };
        Ok(Some(forwarded))
    }
}

impl ClientRequestWriter for &HttpAuthInterceptor {
    /// Serializes `request`'s head into `buffer` with the current lease token.
    ///
    /// If the refresher has no token, an error is logged and the request head
    /// is still serialized with an empty (default) token. This method always
    /// returns `Ok(())`.
    fn write_headers(
        &self,
        request: &httparse::Request,
        buffer: &mut Vec<u8>,
    ) -> ockam_core::Result<()> {
        let token = self.token_refresher.get_token().unwrap_or_else(|| {
            error!("No authorization token available");
            Default::default()
        });

        attach_auth_token_and_serialize_into(request, &token, buffer);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::attach_auth_token_and_serialize_into;
    use crate::http::state::{ClientRequestWriter, RequestState};

    const REQ: &str = "POST / HTTP/1.1\r\n\
Host: www.example.com\r\n\
User-Agent: Mozilla/5.0\r\n\
Accept-Encoding: gzip, deflate, br\r\n\
Transfer-Encoding: gzip, chunked\r\n\r\n\
4\r\nWiki\r\n7\r\npedia i\r\n0\r\n\r\n";

    const TOKEN: &str = "SAMPLE-TOKEN";

    const EXPECTED: &str = "POST / HTTP/1.1\r\n\
Authorization: Token SAMPLE-TOKEN\r\n\
Host: www.example.com\r\n\
User-Agent: Mozilla/5.0\r\n\
Accept-Encoding: gzip, deflate, br\r\n\
Transfer-Encoding: gzip, chunked\r\n\r\n\
4\r\nWiki\r\n7\r\npedia i\r\n0\r\n\r\n";

    /// Chunk sizes covering byte-at-a-time, mid-header and whole-message splits.
    const CHUNK_SIZES: [usize; 4] = [1, 5, 32, 1024];

    /// Writer that always injects the fixed `TOKEN` instead of a leased one.
    struct RequestWriterSimulator;

    impl ClientRequestWriter for RequestWriterSimulator {
        fn write_headers(
            &self,
            request: &httparse::Request,
            buffer: &mut Vec<u8>,
        ) -> ockam_core::Result<()> {
            attach_auth_token_and_serialize_into(request, TOKEN, buffer);
            Ok(())
        }
    }

    /// Streams `input` through a fresh parser once per entry of `CHUNK_SIZES`.
    /// It checks that the rewritten bytes equal `expected` and that the parser
    /// ends back in `RequestState::ParsingHeader(None)`.
    fn assert_rewritten_in_chunks(input: &[u8], expected: &str) {
        for chunk_size in CHUNK_SIZES {
            let mut parser = RequestState::ParsingHeader(None);
            let mut rewritten = Vec::new();
            for piece in input.chunks(chunk_size) {
                let produced = parser
                    .process_http_buffer(piece, RequestWriterSimulator)
                    .unwrap();
                rewritten.extend_from_slice(&produced);
            }
            assert_eq!(String::from_utf8(rewritten).unwrap(), expected);
            assert_eq!(parser, RequestState::ParsingHeader(None));
        }
    }

    #[test]
    fn parse_post_with_chunked_transfers() {
        let input = REQ.repeat(2);
        let expected = EXPECTED.repeat(2);
        assert_rewritten_in_chunks(input.as_bytes(), &expected);
    }

    #[test]
    fn parse_post_with_content_length() {
        let req = "POST /test HTTP/1.1\r\n\
Host: foo.example\r\n\
Content-Type: application/x-www-form-urlencoded\r\n\
Content-Length: 27\r\n\r\n\
field1=value1&field2=value2";
        let expected_r = format!(
            "POST /test HTTP/1.1\r\n\
Authorization: Token {}\r\n\
Host: foo.example\r\n\
Content-Type: application/x-www-form-urlencoded\r\n\
Content-Length: 27\r\n\r\n\
field1=value1&field2=value2",
            TOKEN
        );

        assert_rewritten_in_chunks(req.repeat(2).as_bytes(), &expected_r.repeat(2));
    }

    #[test]
    fn parse_get_requests() {
        let req = "GET /home/user/example.txt HTTP/1.1\r\n\r\n";
        let expected = format!(
            "GET /home/user/example.txt HTTP/1.1\r\nAuthorization: Token {}\r\n\r\n",
            TOKEN
        );

        assert_rewritten_in_chunks(req.repeat(2).as_bytes(), &expected.repeat(2));
    }
}