1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
use crate::{
aggregator::service::{Aggregator, ServiceHandle},
common::client::{ClientId, Credentials, Token},
};
use tokio::net::TcpListener;
use tracing_futures::Instrument;
use warp::{
http::{header::CONTENT_TYPE, method::Method, Response, StatusCode},
Filter,
};
/// Serves the aggregator's HTTP API on `bind_address`.
///
/// Two routes are exposed, both addressed as `/<client_id>/<token>`:
///
/// - `GET`: forwards the credentials to [`ServiceHandle::download`] and, on success, replies
///   with the value it returns as the body, tagged `Content-Type: application/octet-stream`.
/// - `POST`: forwards the credentials and the raw request body to
///   [`ServiceHandle::upload`] and replies `200 OK` on success.
///
/// Any error returned by the service is reported to the client as `404 Not Found`. Both
/// routes accept cross-origin requests from any origin. Each request is handled inside a
/// `trace`-level span whose parent is the span that was current when `serve` was called,
/// and all requests pass through `warp::log("http")`.
///
/// The returned future drives the warp server over the bound listener's incoming
/// connections.
///
/// # Panics
///
/// Panics if the TCP listener cannot be bound to `bind_address`.
pub async fn serve<A: Aggregator + 'static>(bind_address: &str, handle: ServiceHandle<A>) {
    // Give every matched request its own clone of the service handle.
    let handle = warp::any().map(move || handle.clone());

    // NOTE(review): neither route below is terminated with `warp::path::end()`, so requests
    // with extra trailing path segments (e.g. `/<id>/<token>/extra`) also match — confirm
    // this is intended.

    // Captured here so each per-request span is explicitly parented to the span that was
    // active when `serve` was called.
    let parent_span = tracing::Span::current();
    let download_global_weights = warp::get()
        .and(warp::path::param::<ClientId>())
        .and(warp::path::param::<Token>())
        .and(handle.clone())
        .and_then(move |id, token, handle: ServiceHandle<A>| {
            let span =
                trace_span!(parent: parent_span.clone(), "api_download_request", client_id = %id);
            async move {
                debug!("received download request");
                match handle.download(Credentials(id, token)).await {
                    Ok(weights) => Ok(Response::builder().body(weights)),
                    Err(_) => {
                        // The error value is not exposed to the client; record that the
                        // request failed so 404s can be told apart from unknown routes.
                        debug!("download request failed, replying with 404");
                        Err(warp::reject::not_found())
                    }
                }
            }
            .instrument(span)
        })
        .with(warp::cors().allow_any_origin().allow_method(Method::GET))
        .with(warp::reply::with::header(
            CONTENT_TYPE,
            "application/octet-stream",
        ));

    let parent_span = tracing::Span::current();
    let upload_local_weights = warp::post()
        .and(warp::path::param::<ClientId>())
        .and(warp::path::param::<Token>())
        .and(warp::body::bytes())
        .and(handle.clone())
        .and_then(move |id, token, weights, handle: ServiceHandle<A>| {
            let span =
                trace_span!(parent: parent_span.clone(), "api_upload_request", client_id = %id);
            async move {
                debug!("received upload request");
                match handle.upload(Credentials(id, token), weights).await {
                    Ok(()) => Ok(StatusCode::OK),
                    Err(_) => {
                        debug!("upload request failed, replying with 404");
                        Err(warp::reject::not_found())
                    }
                }
            }
            .instrument(span)
        })
        .with(
            warp::cors()
                .allow_any_origin()
                .allow_method(Method::POST)
                // Browsers send a preflight for POSTs carrying a Content-Type header.
                .allow_header(CONTENT_TYPE),
        );

    // Binding failure is fatal for the server; include the address and the OS error so the
    // panic is actionable instead of a bare `unwrap` on `Err`.
    let mut listener = TcpListener::bind(bind_address)
        .await
        .unwrap_or_else(|err| panic!("failed to bind HTTP server to {}: {}", bind_address, err));
    info!("starting HTTP server on {}", bind_address);

    let log = warp::log("http");
    warp::serve(download_global_weights.or(upload_local_weights).with(log))
        .run_incoming(listener.incoming())
        .await
}