use std::sync::Arc;
use tracing::{Instrument, debug, debug_span};
use crate::client::{
ClientConfig, DOWNLOAD_SNIFF_BYTES, MAX_MEASUREMENTS_PER_UPLOAD, MAX_RESPONSE_BYTES,
MAX_UPLOAD_BYTES,
};
use crate::error::{Error, ParseError, Result};
#[cfg(feature = "csv")]
use crate::internal::endpoint::add_get_in_area_params;
use crate::internal::endpoint::{
Endpoint, add_cell_get_params, add_dump_params, add_get_in_area_size_params,
add_measurement_params, build_url, build_url_with_token, prepare_upload,
};
use crate::internal::parse::{
check_upload_response, finalize_response_body, parse_diff_listing, parse_json,
validate_dump_head,
};
use crate::internal::tracing::redact_api_key;
#[cfg(feature = "csv")]
use crate::params::GetCellsInAreaParams;
use crate::params::AreaQuery;
use crate::types::{Cell, CellCount, CellKey, DumpKind, DumpListing, MeasurementsPayload};
#[derive(Debug)]
struct Inner {
config: ClientConfig,
http: reqwest::Client,
}
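/// Asynchronous client for the OpenCellID HTTP API.
///
/// Cloning is cheap: configuration and the underlying `reqwest::Client`
/// are shared behind an `Arc`.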
#[derive(Debug, Clone)]
pub struct Client {
inner: Arc<Inner>,
}
impl Client {
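    /// Assembles a client from a prepared configuration and HTTP client.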
pub(crate) fn from_parts(config: ClientConfig, http: reqwest::Client) -> Self {
Self { inner: Arc::new(Inner { config, http }) }
}
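    /// Fetches a single cell identified by its `(mcc, mnc, lac, cell_id)` key.
    ///
    /// A minimal usage sketch; the struct-literal form of `CellKey` is an
    /// assumption here, and the real type may expose a constructor instead:
    ///
    /// ```ignore
    /// let cell = client
    ///     .get_cell(CellKey { mcc: 262, mnc: 2, lac: 1, cell_id: 1234 })
    ///     .await?;
    /// ```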
pub async fn get_cell(&self, key: CellKey) -> Result<Cell> {
let span = debug_span!(
"opencellid.get_cell",
mcc = key.mcc,
mnc = key.mnc,
lac = key.lac,
cell_id = key.cell_id,
);
async move {
let mut url = build_url(
&self.inner.config.base_url,
Endpoint::CellGet,
&self.inner.config.api_key,
)?;
add_cell_get_params(&mut url, &key);
self.get_json::<Cell>(url).await
}
.instrument(span)
.await
}
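    /// Returns the number of cells inside `query`'s area without fetching
    /// the cells themselves.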
pub async fn get_cells_in_area_size(&self, query: AreaQuery) -> Result<CellCount> {
let span = debug_span!("opencellid.get_cells_in_area_size");
async move {
let mut url = build_url(
&self.inner.config.base_url,
Endpoint::CellGetInAreaSize,
&self.inner.config.api_key,
)?;
add_get_in_area_size_params(&mut url, &query);
self.get_json::<CellCount>(url).await
}
.instrument(span)
.await
}
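    /// Fetches the cells inside the requested area, parsing the server's
    /// CSV response. Only available with the `csv` feature.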
#[cfg(feature = "csv")]
#[cfg_attr(docsrs, doc(cfg(feature = "csv")))]
pub async fn get_cells_in_area(&self, params: GetCellsInAreaParams) -> Result<Vec<Cell>> {
let span = debug_span!("opencellid.get_cells_in_area");
async move {
let mut url = build_url(
&self.inner.config.base_url,
Endpoint::CellGetInArea,
&self.inner.config.api_key,
)?;
            add_get_in_area_params(&mut url, &params);
let body = self.get_text(url).await?;
crate::internal::parse::parse_cells_csv(&body)
}
.instrument(span)
.await
}
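    /// Submits a single measurement. The measurement is validated locally
    /// before any request is sent; the response body is discarded on
    /// success.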
pub async fn add_measurement(&self, m: &crate::types::Measurement) -> Result<()> {
m.validate()?;
let span = debug_span!("opencellid.add_measurement", mcc = m.mcc, mnc = m.mnc);
async move {
let mut url = build_url(
&self.inner.config.base_url,
Endpoint::MeasureAdd,
&self.inner.config.api_key,
)?;
add_measurement_params(&mut url, m);
let _ = self.get_text(url).await?;
Ok(())
}
.instrument(span)
.await
}
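    /// Uploads measurements as a CSV file via multipart POST.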
pub async fn upload_csv(&self, csv: impl Into<Vec<u8>>) -> Result<()> {
self.upload_multipart(Endpoint::MeasureUploadCsv, csv.into()).await
}
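    /// Uploads a batch of measurements as JSON, rejecting batches larger
    /// than `MAX_MEASUREMENTS_PER_UPLOAD` and validating every entry before
    /// serialisation.
    ///
    /// A sketch, assuming `MeasurementsPayload` can be built directly from
    /// its `measurements` field:
    ///
    /// ```ignore
    /// let payload = MeasurementsPayload { measurements };
    /// client.upload_json(&payload).await?;
    /// ```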
pub async fn upload_json(&self, payload: &MeasurementsPayload) -> Result<()> {
if payload.measurements.len() > MAX_MEASUREMENTS_PER_UPLOAD {
return Err(Error::InvalidInput(format!(
"measurements batch is {} entries, exceeds {} limit",
payload.measurements.len(),
MAX_MEASUREMENTS_PER_UPLOAD
)));
}
for m in &payload.measurements {
m.validate()?;
}
let body = serde_json::to_vec(payload)
.map_err(|e| Error::Parse(ParseError::with_source("serialise payload", e)))?;
self.upload_multipart(Endpoint::MeasureUploadJson, body).await
}
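    /// Uploads measurements in CLF format via multipart POST.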
pub async fn upload_clf(&self, clf: impl Into<Vec<u8>>) -> Result<()> {
self.upload_multipart(Endpoint::MeasureUploadClf, clf.into()).await
}
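    /// Shared upload path: size-checks the body via `prepare_upload`, posts
    /// it as the `datafile` part alongside the API key, and verifies the
    /// server's textual response.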
async fn upload_multipart(&self, endpoint: Endpoint, body: Vec<u8>) -> Result<()> {
let prepared = prepare_upload(endpoint, body.len(), MAX_UPLOAD_BYTES)?;
let span = debug_span!(
"opencellid.upload",
endpoint = endpoint.path(),
bytes = body.len()
);
async move {
let url = build_url(&self.inner.config.base_url, endpoint, &self.inner.config.api_key)?;
debug!(url = %redact_api_key(&url), "POST multipart");
let part = reqwest::multipart::Part::bytes(body)
.file_name(prepared.filename)
.mime_str(prepared.mime)
.map_err(|e| Error::InvalidInput(format!("mime: {e}")))?;
let form = reqwest::multipart::Form::new()
.text("key", self.inner.config.api_key.to_string())
.part("datafile", part);
let resp = self.inner.http.post(url).multipart(form).send().await?;
let body = read_text_with_limit(resp).await?;
check_upload_response(&body)
}
.instrument(span)
.await
}
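    /// Issues a GET request and reads the body, bounded by
    /// `MAX_RESPONSE_BYTES`.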
async fn get_text(&self, url: url::Url) -> Result<String> {
debug!(url = %redact_api_key(&url), "GET");
let resp = self.inner.http.get(url).send().await?;
read_text_with_limit(resp).await
}
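    /// Issues a GET request and deserialises the JSON body into `T`.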
async fn get_json<T: for<'de> serde::Deserialize<'de>>(&self, url: url::Url) -> Result<T> {
let body = self.get_text(url).await?;
parse_json(&body)
}
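    /// Streams a database dump into `writer`, returning the number of bytes
    /// written.
    ///
    /// The first `DOWNLOAD_SNIFF_BYTES` are buffered and validated before
    /// anything reaches `writer`, so an HTML error page is rejected rather
    /// than saved as a dump; the transfer is aborted once the configured
    /// `max_dump_bytes` is exceeded.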
pub async fn download_dump<W>(&self, kind: DumpKind, writer: &mut W) -> Result<u64>
where
W: tokio::io::AsyncWrite + Unpin,
{
let span = debug_span!(
"opencellid.download_dump",
kind = dump_kind_tag(&kind),
bytes = tracing::field::Empty,
);
async move {
let mut url = build_url_with_token(
&self.inner.config.base_url,
Endpoint::Downloads,
&self.inner.config.api_key,
)?;
add_dump_params(&mut url, &kind)?;
debug!(url = %redact_api_key(&url), "GET dump");
let mut resp = self
.inner
.http
.get(url)
.timeout(self.inner.config.download_timeout)
.send()
.await?;
let max_dump_bytes = self.inner.config.max_dump_bytes;
if let Some(len) = resp.content_length() {
if len > max_dump_bytes {
return Err(Error::Parse(ParseError::new(format!(
"dump body advertised {len} bytes, exceeds {max_dump_bytes} limit"
))));
}
}
let status = resp.status();
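            // Buffer the first DOWNLOAD_SNIFF_BYTES so an HTML error page
            // can be detected before anything is written to `writer`.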
let mut head: Vec<u8> = Vec::with_capacity(DOWNLOAD_SNIFF_BYTES);
while head.len() < DOWNLOAD_SNIFF_BYTES {
match resp.chunk().await? {
Some(chunk) => head.extend_from_slice(&chunk),
None => break,
}
}
validate_dump_head(status, &head)?;
use tokio::io::AsyncWriteExt as _;
writer.write_all(&head).await.map_err(|e| {
Error::Parse(ParseError::with_source("write dump body", e))
})?;
let mut total = head.len() as u64;
while let Some(chunk) = resp.chunk().await? {
if total + chunk.len() as u64 > max_dump_bytes {
return Err(Error::Parse(ParseError::new(format!(
"dump body exceeded {max_dump_bytes} byte limit"
))));
}
writer.write_all(&chunk).await.map_err(|e| {
Error::Parse(ParseError::with_source("write dump body", e))
})?;
total += chunk.len() as u64;
}
writer
.flush()
.await
.map_err(|e| Error::Parse(ParseError::with_source("flush dump body", e)))?;
tracing::Span::current().record("bytes", total);
tracing::trace!(bytes = total, "dump streamed");
Ok(total)
}
.instrument(span)
.await
}
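    /// Downloads a dump to `path`, streaming into a sibling `.part` file
    /// and renaming it into place only on success; the partial file is
    /// removed if the download fails.
    ///
    /// A sketch with a hypothetical destination file name:
    ///
    /// ```ignore
    /// let bytes = client
    ///     .download_dump_to_path(DumpKind::World, "towers.csv.gz")
    ///     .await?;
    /// ```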
pub async fn download_dump_to_path(
&self,
kind: DumpKind,
path: impl AsRef<std::path::Path>,
) -> Result<u64> {
let final_path = path.as_ref().to_path_buf();
let mut part_os = final_path.as_os_str().to_owned();
part_os.push(".part");
let part_path = std::path::PathBuf::from(part_os);
let mut file = tokio::fs::File::create(&part_path)
.await
.map_err(|e| Error::Parse(ParseError::with_source("create dump file", e)))?;
let result = self.download_dump(kind, &mut file).await;
match result {
Ok(n) => {
drop(file);
tokio::fs::rename(&part_path, &final_path).await.map_err(|e| {
Error::Parse(ParseError::with_source("rename dump file", e))
})?;
Ok(n)
}
Err(e) => {
drop(file);
if let Err(rm_err) = tokio::fs::remove_file(&part_path).await {
tracing::warn!(
error = %rm_err,
path = %part_path.display(),
"failed to remove partial dump"
);
}
Err(e)
}
}
}
}
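/// Static tag used as the `kind` field on download spans.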
fn dump_kind_tag(kind: &DumpKind) -> &'static str {
match kind {
DumpKind::World => "world",
DumpKind::Country(_) => "mcc",
DumpKind::Daily { .. } => "diff",
}
}
impl Client {
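    /// Fetches the downloads listing page and scrapes the daily diff
    /// entries out of its HTML.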
pub async fn list_daily_diffs(&self) -> Result<Vec<DumpListing>> {
let span = debug_span!("opencellid.list_daily_diffs");
async move {
let url = build_url_with_token(
&self.inner.config.base_url,
Endpoint::DownloadsList,
&self.inner.config.api_key,
)?;
let html = self.get_text(url).await?;
Ok(parse_diff_listing(&html))
}
.instrument(span)
.await
}
}
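/// Reads a response body as text, enforcing `MAX_RESPONSE_BYTES` both
/// against the advertised `Content-Length` and while streaming chunks.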
async fn read_text_with_limit(mut resp: reqwest::Response) -> Result<String> {
    let status = resp.status();
    if let Some(len) = resp.content_length() {
        if len > MAX_RESPONSE_BYTES as u64 {
            return Err(Error::Parse(ParseError::new(format!(
                "response body advertised {len} bytes, exceeds {MAX_RESPONSE_BYTES} limit"
            ))));
        }
    }
    // Pre-size from Content-Length when present (already bounded above);
    // fall back to a small default for chunked responses.
    let cap = resp
        .content_length()
        .map(|n| (n as usize).min(MAX_RESPONSE_BYTES))
        .unwrap_or(8 * 1024);
    let mut buf: Vec<u8> = Vec::with_capacity(cap);
while let Some(chunk) = resp.chunk().await? {
if buf.len() + chunk.len() > MAX_RESPONSE_BYTES {
return Err(Error::Parse(ParseError::new(format!(
"response body exceeded {MAX_RESPONSE_BYTES} byte limit"
))));
}
buf.extend_from_slice(&chunk);
}
finalize_response_body(status, buf)
}