ic_http_gateway/response/
response_handler.rs

1use crate::protocol::validate;
2use crate::{HttpGatewayResponseBody, ResponseBodyStream};
3use bytes::Bytes;
4use candid::Principal;
5use futures::{stream, Stream, StreamExt, TryStreamExt};
6use http_body::Frame;
7use http_body_util::{BodyExt, Full};
8use ic_agent::{Agent, AgentError};
9use ic_http_certification::{HttpRequest, HttpResponse, StatusCode};
10use ic_response_verification::MAX_VERIFICATION_VERSION;
11use ic_utils::interfaces::http_request::HeaderField;
12use ic_utils::{
13    call::SyncCall,
14    interfaces::http_request::{
15        HttpRequestCanister, HttpRequestStreamingCallbackAny, HttpResponse as AgentResponse,
16        StreamingCallbackHttpResponse, StreamingStrategy, Token,
17    },
18};
19
/// Limit the total number of calls to an HTTP request loop to 1000 for now.
const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 1000;

/// Limit the total number of calls to an HTTP request loop that can be verified.
const MAX_VERIFIED_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 4;

/// Limit the number of stream callbacks buffered.
const STREAM_CALLBACK_BUFFER: usize = 2;
28
29pub type AgentResponseAny = AgentResponse<Token, HttpRequestStreamingCallbackAny>;
30
31pub async fn get_body_and_streaming_body(
32    agent: &Agent,
33    response: &AgentResponseAny,
34) -> Result<HttpGatewayResponseBody, AgentError> {
35    // if we already have the full body, we can return it early
36    let Some(StreamingStrategy::Callback(callback_strategy)) = response.streaming_strategy.clone()
37    else {
38        return Ok(HttpGatewayResponseBody::Right(Full::from(
39            response.body.clone(),
40        )));
41    };
42
43    let (streamed_body, token) = create_stream(
44        agent.clone(),
45        callback_strategy.callback.clone(),
46        Some(callback_strategy.token),
47    )
48    .take(MAX_VERIFIED_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
49    .map(|x| async move { x })
50    .buffered(STREAM_CALLBACK_BUFFER)
51    .try_fold(
52        (vec![], None::<Token>),
53        |mut accum, (mut body, token)| async move {
54            accum.0.append(&mut body);
55            accum.1 = token;
56
57            Ok(accum)
58        },
59    )
60    .await?;
61
62    let streamed_body = [response.body.clone(), streamed_body].concat();
63
64    // if we still have a token at this point,
65    // we were unable to collect the response within the allowed certified callback limit,
66    // fallback to uncertified streaming using what we've streamed so far as the initial body
67    if token.is_some() {
68        let body_stream = create_body_stream(
69            agent.clone(),
70            callback_strategy.callback,
71            token,
72            streamed_body,
73        );
74
75        return Ok(HttpGatewayResponseBody::Left(body_stream));
76    };
77
78    // if we no longer have a token at this point,
79    // we were able to collect the response within the allow certified callback limit,
80    // return this collected response as a standard response body so it will be verified
81    Ok(HttpGatewayResponseBody::Right(Full::from(streamed_body)))
82}
83
84fn create_body_stream(
85    agent: Agent,
86    callback: HttpRequestStreamingCallbackAny,
87    token: Option<Token>,
88    initial_body: Vec<u8>,
89) -> ResponseBodyStream {
90    let chunks_stream = create_stream(agent, callback, token)
91        .map(|chunk| chunk.map(|(body, _)| Frame::data(Bytes::from(body))));
92
93    let body_stream = stream::once(async move { Ok(Frame::data(Bytes::from(initial_body))) })
94        .chain(chunks_stream)
95        .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
96        .map(|x| async move { x })
97        .buffered(STREAM_CALLBACK_BUFFER);
98
99    ResponseBodyStream::new(Box::pin(body_stream))
100}
101
102fn create_stream(
103    agent: Agent,
104    callback: HttpRequestStreamingCallbackAny,
105    token: Option<Token>,
106) -> impl Stream<Item = Result<(Vec<u8>, Option<Token>), AgentError>> {
107    futures::stream::try_unfold(
108        (agent, callback, token),
109        |(agent, callback, token)| async move {
110            let Some(token) = token else {
111                return Ok(None);
112            };
113
114            let canister = HttpRequestCanister::create(&agent, callback.0.principal);
115            match canister
116                .http_request_stream_callback(&callback.0.method, token)
117                .call()
118                .await
119            {
120                Ok((StreamingCallbackHttpResponse { body, token },)) => {
121                    Ok(Some(((body, token.clone()), (agent, callback, token))))
122                }
123                Err(e) => Err(e),
124            }
125        },
126    )
127}
128
129#[derive(Clone, Debug)]
130struct StreamState<'a> {
131    pub http_request: HttpRequest<'a>,
132    pub canister_id: Principal,
133    pub total_length: usize,
134    pub fetched_length: usize,
135    pub skip_verification: bool,
136}
137
138pub async fn get_206_stream_response_body_and_total_length(
139    agent: &Agent,
140    http_request: HttpRequest<'static>,
141    canister_id: Principal,
142    response_headers: &Vec<HeaderField<'static>>,
143    response_206_body: HttpGatewayResponseBody,
144    skip_verification: bool,
145) -> Result<(HttpGatewayResponseBody, usize), AgentError> {
146    let HttpGatewayResponseBody::Right(body) = response_206_body else {
147        return Err(AgentError::InvalidHttpResponse(
148            "Expected full 206 response".to_string(),
149        ));
150    };
151    // The expect below should never panic because `Either::Right` will always have a full body
152    let streamed_body = body
153        .collect()
154        .await
155        .expect("missing streamed chunk body")
156        .to_bytes()
157        .to_vec();
158    let stream_state = get_initial_stream_state(
159        http_request,
160        canister_id,
161        response_headers,
162        skip_verification,
163    )?;
164    let content_length = stream_state.total_length;
165
166    let body_stream = create_206_body_stream(agent.clone(), stream_state, streamed_body);
167    Ok((HttpGatewayResponseBody::Left(body_stream), content_length))
168}
169
/// Numeric fields of a `Content-Range: bytes <begin>-<end>/<total>` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ContentRangeValues {
    /// Offset of the first byte of the range.
    pub range_begin: usize,
    /// Offset of the last byte of the range (inclusive).
    pub range_end: usize,
    /// Total length of the complete resource.
    pub total_length: usize,
}
176
177fn parse_content_range_header_str(
178    content_range_str: &str,
179) -> Result<ContentRangeValues, AgentError> {
180    // expected format: `bytes 21010-47021/47022`
181    let str_value = content_range_str.trim();
182    if !str_value.starts_with("bytes ") {
183        return Err(AgentError::InvalidHttpResponse(format!(
184            "Invalid Content-Range header '{}'",
185            content_range_str
186        )));
187    }
188    let str_value = str_value.trim_start_matches("bytes ");
189
190    let str_value_parts = str_value.split('-').collect::<Vec<_>>();
191    if str_value_parts.len() != 2 {
192        return Err(AgentError::InvalidHttpResponse(format!(
193            "Invalid bytes spec in Content-Range header '{}'",
194            content_range_str
195        )));
196    }
197    let range_begin = str_value_parts[0].parse::<usize>().map_err(|e| {
198        AgentError::InvalidHttpResponse(format!(
199            "Invalid range_begin in '{}': {}",
200            content_range_str, e
201        ))
202    })?;
203
204    let other_value_parts = str_value_parts[1].split('/').collect::<Vec<_>>();
205    if other_value_parts.len() != 2 {
206        return Err(AgentError::InvalidHttpResponse(format!(
207            "Invalid bytes spec in Content-Range header '{}'",
208            content_range_str
209        )));
210    }
211    let range_end = other_value_parts[0].parse::<usize>().map_err(|e| {
212        AgentError::InvalidHttpResponse(format!(
213            "Invalid range_end in '{}': {}",
214            content_range_str, e
215        ))
216    })?;
217    let total_length = other_value_parts[1].parse::<usize>().map_err(|e| {
218        AgentError::InvalidHttpResponse(format!(
219            "Invalid total_length in '{}': {}",
220            content_range_str, e
221        ))
222    })?;
223
224    let rv = ContentRangeValues {
225        range_begin,
226        range_end,
227        total_length,
228    };
229    if rv.range_begin > rv.range_end
230        || rv.range_begin >= rv.total_length
231        || rv.range_end >= rv.total_length
232    {
233        Err(AgentError::InvalidHttpResponse(format!(
234            "inconsistent Content-Range header {}: {:?}",
235            content_range_str, rv
236        )))
237    } else {
238        Ok(rv)
239    }
240}
241
242fn get_content_range_header_str(
243    response_headers: &Vec<HeaderField<'static>>,
244) -> Result<String, AgentError> {
245    for HeaderField(name, value) in response_headers {
246        if name.eq_ignore_ascii_case(http::header::CONTENT_RANGE.as_ref()) {
247            return Ok(value.to_string());
248        }
249    }
250    Err(AgentError::InvalidHttpResponse(
251        "missing Content-Range header in 206 response".to_string(),
252    ))
253}
254
255fn get_content_range_values(
256    response_headers: &Vec<HeaderField<'static>>,
257    fetched_length: usize,
258) -> Result<ContentRangeValues, AgentError> {
259    let str_value = get_content_range_header_str(response_headers)?;
260    let range_values = parse_content_range_header_str(&str_value)?;
261
262    if range_values.range_begin > fetched_length {
263        return Err(AgentError::InvalidHttpResponse(format!(
264            "chunk out-of-order: range_begin={} is larger than expected begin={} ",
265            range_values.range_begin, fetched_length
266        )));
267    }
268    if range_values.range_end < fetched_length {
269        return Err(AgentError::InvalidHttpResponse(format!(
270            "chunk out-of-order: range_end={} is smaller than length fetched so far={} ",
271            range_values.range_begin, fetched_length
272        )));
273    }
274    Ok(range_values)
275}
276
277fn get_initial_stream_state<'a>(
278    http_request: HttpRequest<'a>,
279    canister_id: Principal,
280    response_headers: &Vec<HeaderField<'static>>,
281    skip_verification: bool,
282) -> Result<StreamState<'a>, AgentError> {
283    let range_values = get_content_range_values(response_headers, 0)?;
284
285    Ok(StreamState {
286        http_request,
287        canister_id,
288        total_length: range_values.total_length,
289        fetched_length: range_values
290            .range_end
291            .saturating_sub(range_values.range_begin)
292            + 1,
293        skip_verification,
294    })
295}
296
297fn create_206_body_stream(
298    agent: Agent,
299    stream_state: StreamState<'static>,
300    initial_body: Vec<u8>,
301) -> ResponseBodyStream {
302    let chunks_stream = create_206_stream(agent, Some(stream_state))
303        .map(|chunk| chunk.map(|(body, _)| Frame::data(Bytes::from(body))));
304
305    let body_stream = stream::once(async move { Ok(Frame::data(Bytes::from(initial_body))) })
306        .chain(chunks_stream)
307        .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
308        .map(|x| async move { x })
309        .buffered(STREAM_CALLBACK_BUFFER);
310
311    ResponseBodyStream::new(Box::pin(body_stream))
312}
313
314fn create_206_stream(
315    agent: Agent,
316    maybe_stream_state: Option<StreamState>,
317) -> impl Stream<Item = Result<(Vec<u8>, Option<StreamState>), AgentError>> {
318    futures::stream::try_unfold(
319        (agent, maybe_stream_state),
320        |(agent, maybe_stream_state)| async move {
321            let Some(stream_state) = maybe_stream_state else {
322                return Ok(None);
323            };
324            let canister = HttpRequestCanister::create(&agent, stream_state.canister_id);
325            let next_chunk_begin = stream_state.fetched_length;
326
327            let range_header = ("Range".to_string(), format!("bytes={}-", next_chunk_begin));
328            let mut updated_headers = stream_state.http_request.headers().to_vec();
329            updated_headers.push(range_header.clone());
330            let headers = updated_headers
331                .iter()
332                .map(|(name, value)| HeaderField(name.into(), value.into()))
333                .collect::<Vec<HeaderField>>()
334                .into_iter();
335            let query_result = canister
336                .http_request(
337                    &stream_state.http_request.method(),
338                    &stream_state.http_request.url(),
339                    headers,
340                    &stream_state.http_request.body(),
341                    Some(&u16::from(MAX_VERIFICATION_VERSION)),
342                )
343                .call()
344                .await;
345            let agent_response = match query_result {
346                Ok((response,)) => response,
347                Err(e) => return Err(e),
348            };
349            let range_values =
350                get_content_range_values(&agent_response.headers, stream_state.fetched_length)?;
351            let new_bytes_begin = stream_state
352                .fetched_length
353                .saturating_sub(range_values.range_begin);
354            let chunk_length = range_values
355                .range_end
356                .saturating_sub(stream_state.fetched_length)
357                + 1;
358            let current_fetched_length = stream_state.fetched_length + chunk_length;
359            // Verify the chunk from the range response.
360            if agent_response.streaming_strategy.is_some() {
361                return Err(AgentError::InvalidHttpResponse(
362                    "unexpected StreamingStrategy".to_string(),
363                ));
364            }
365
366            let Ok(status_code) = StatusCode::from_u16(agent_response.status_code) else {
367                return Err(AgentError::InvalidHttpResponse(format!(
368                    "Invalid canister response status code: {}",
369                    agent_response.status_code
370                )));
371            };
372            let response = HttpResponse::builder()
373                .with_status_code(status_code)
374                .with_headers(
375                    agent_response
376                        .headers
377                        .iter()
378                        .map(|HeaderField(k, v)| (k.to_string(), v.to_string()))
379                        .collect(),
380                )
381                .with_body(agent_response.body.clone())
382                .build();
383            let mut http_request = stream_state.http_request.clone();
384            http_request.headers_mut().push(range_header);
385            let validation_result = validate(
386                &agent,
387                &stream_state.canister_id,
388                http_request,
389                response,
390                stream_state.skip_verification,
391            );
392
393            if let Err(e) = validation_result {
394                return Err(AgentError::InvalidHttpResponse(format!(
395                    "CertificateVerificationFailed for a chunk starting at {}, error: {}",
396                    stream_state.fetched_length, e
397                )));
398            }
399            let maybe_new_state = if current_fetched_length < stream_state.total_length {
400                Some(StreamState {
401                    fetched_length: current_fetched_length,
402                    ..stream_state
403                })
404            } else {
405                None
406            };
407            Ok(Some((
408                (
409                    agent_response.body[new_bytes_begin..].to_vec(),
410                    maybe_new_state.clone(),
411                ),
412                (agent, maybe_new_state),
413            )))
414        },
415    )
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::borrow::Cow;

    /// Request fixture shared by the `get_initial_stream_state` tests.
    fn sample_request() -> HttpRequest<'static> {
        HttpRequest::get("http://example.com/some_file")
            .with_headers(vec![("Xyz".to_string(), "some value".to_string())])
            .with_body(vec![42])
            .build()
    }

    /// Canister id fixture shared by the `get_initial_stream_state` tests.
    fn sample_canister_id() -> Principal {
        Principal::from_slice(&[1, 2, 3, 4])
    }

    /// Builds a header list containing exactly one header.
    fn single_header(name: &'static str, value: &'static str) -> Vec<HeaderField<'static>> {
        vec![HeaderField(Cow::from(name), Cow::from(value))]
    }

    #[test]
    fn should_parse_content_range_header_str() {
        // (range_begin, range_end, total_length)
        let cases: [(usize, usize, usize); 3] =
            [(0, 0, 1), (100, 2000, 3000), (10_000, 300_000, 500_000)];
        for (range_begin, range_end, total_length) in cases {
            let input = format!("bytes {}-{}/{}", range_begin, range_end, total_length);
            let output = parse_content_range_header_str(&input)
                .unwrap_or_else(|_| panic!("failed parsing '{}'", input));
            assert_eq!(range_begin, output.range_begin);
            assert_eq!(range_end, output.range_end);
            assert_eq!(total_length, output.total_length);
        }
    }

    #[test]
    fn should_fail_parse_content_range_header_str_on_malformed_input() {
        for input in [
            "byte 1-2/3",
            "bites 2-4/8",
            "bytes 100-200/asdf",
            "bytes 12345",
            "something else",
            "bytes dead-beef/123456",
        ] {
            let result = parse_content_range_header_str(input);
            assert_matches!(result, Err(e) if e.to_string().contains("Invalid "));
        }
    }

    #[test]
    fn should_fail_parse_content_range_header_str_on_inconsistent_input() {
        for input in ["bytes 100-200/190", "bytes 200-150/400", "bytes 100-110/40"] {
            let result = parse_content_range_header_str(input);
            assert_matches!(result, Err(e) if e.to_string().contains("inconsistent Content-Range header"));
        }
    }

    #[test]
    fn should_get_initial_stream_state() {
        let http_request = sample_request();
        let canister_id = sample_canister_id();
        // fetched 3 bytes, total length is 10
        let response_headers = single_header("Content-Range", "bytes 0-2/10");
        let skip_verification = false;

        let state = get_initial_stream_state(
            http_request.clone(),
            canister_id,
            &response_headers,
            skip_verification,
        )
        .expect("failed constructing StreamState");

        assert_eq!(state.http_request, http_request);
        assert_eq!(state.canister_id, canister_id);
        assert_eq!(state.fetched_length, 3);
        assert_eq!(state.total_length, 10);
        assert_eq!(state.skip_verification, skip_verification);
    }

    #[test]
    fn should_fail_get_initial_stream_state_without_content_range_header() {
        let response_headers = single_header("other header", "other value");
        let result = get_initial_stream_state(
            sample_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(result, Err(e) if e.to_string().contains("missing Content-Range header"));
    }

    #[test]
    fn should_fail_get_initial_stream_state_with_malformed_content_range_header() {
        let response_headers = single_header("Content-Range", "bytes 42/10");
        let result = get_initial_stream_state(
            sample_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(result, Err(e) if e.to_string().contains("Invalid bytes spec in Content-Range header"));
    }

    #[test]
    fn should_fail_get_initial_stream_state_with_inconsistent_content_range_header() {
        let response_headers = single_header("Content-Range", "bytes 40-100/90");
        let result = get_initial_stream_state(
            sample_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(result, Err(e) if e.to_string().contains("inconsistent Content-Range header"));
    }
}