use std::collections::BTreeMap;
use anyhow::{Context, Result};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use bytes::Bytes;
use serde::Deserialize;
use agent_os_sidecar::protocol::{
OwnershipScope, RejectedResponse, RequestPayload, ResponsePayload, VmFetchRequest,
};
use crate::agent_os::AgentOs;
use crate::error::ClientError;
/// JSON document carried in the `response_json` of a `VmFetchResult` reply.
///
/// Wire keys are camelCase (`statusText`). Only `status` is required; every
/// other field falls back to `None` when it is absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VmFetchResponsePayload {
    /// Numeric HTTP status code.
    status: u16,
    /// Optional reason phrase (wire key `statusText`).
    #[serde(default)]
    status_text: Option<String>,
    /// Response headers as an ordered list of `(name, value)` pairs.
    #[serde(default)]
    headers: Option<Vec<(String, String)>>,
    /// Response body, base64-encoded.
    #[serde(default)]
    body: Option<String>,
}
impl AgentOs {
pub async fn fetch(
&self,
port: u16,
request: http::Request<Bytes>,
) -> Result<http::Response<Bytes>> {
let (parts, body) = request.into_parts();
let path = match parts.uri.path_and_query() {
Some(pq) => pq.as_str().to_owned(),
None => "/".to_owned(),
};
let method = parts.method.as_str().to_owned();
let mut header_map: BTreeMap<String, String> = BTreeMap::new();
for (name, value) in parts.headers.iter() {
header_map.insert(
name.as_str().to_owned(),
String::from_utf8_lossy(value.as_bytes()).into_owned(),
);
}
let headers_json =
serde_json::to_string(&header_map).context("serializing fetch request headers")?;
let wire_body = if method == "GET" || method == "HEAD" {
None
} else {
Some(String::from_utf8_lossy(&body).into_owned())
};
let response = self
.transport()
.request(
self.vm_fetch_ownership(),
RequestPayload::VmFetch(VmFetchRequest {
port,
method,
path,
headers_json,
body: wire_body,
}),
)
.await?;
let response_json = match response {
ResponsePayload::VmFetchResult(result) => result.response_json,
ResponsePayload::Rejected(RejectedResponse { code, message }) => {
return Err(ClientError::Kernel { code, message }.into());
}
other => {
return Err(ClientError::Sidecar(format!(
"fetch: unexpected response {other:?}"
))
.into());
}
};
let payload: VmFetchResponsePayload =
serde_json::from_str(&response_json).context("parsing vm_fetch response json")?;
let decoded_body = match payload.body {
Some(encoded) => Bytes::from(
BASE64
.decode(encoded.as_bytes())
.context("decoding base64 fetch response body")?,
),
None => Bytes::new(),
};
let status = http::StatusCode::from_u16(payload.status)
.context("fetch: invalid response status code")?;
let mut builder = http::Response::builder().status(status);
for (key, value) in payload.headers.unwrap_or_default() {
builder = builder.header(key, value);
}
let mut http_response = builder
.body(decoded_body)
.context("building fetch response")?;
if let Some(status_text) = payload.status_text {
http_response.extensions_mut().insert(FetchStatusText(status_text));
}
Ok(http_response)
}
fn vm_fetch_ownership(&self) -> OwnershipScope {
OwnershipScope::vm(self.connection_id(), self.wire_session_id(), self.vm_id())
}
}
/// HTTP reason phrase (status text) reported by the VM for a fetch response.
///
/// `http::Response` has no field for a reason phrase, so [`AgentOs::fetch`]
/// attaches it as a response extension when the sidecar provides one. Read it
/// with `response.extensions().get::<FetchStatusText>()`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FetchStatusText(pub String);