use std::collections::HashMap;
use std::fs::File;
use std::io::Result;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::time::{Duration, Instant};
use nix::sys::select::{select, FdSet};
use vm_memory::ByteValued;
use crate::cache::state::{BlobRangeMap, RangeMap};
use crate::device::{BlobInfo, BlobIoRange, BlobObject};
use crate::remote::connection::Endpoint;
use crate::remote::message::{
FetchRangeReply, FetchRangeRequest, FetchRangeResult, GetBlobReply, GetBlobRequest, HeaderFlag,
MsgHeader, MsgValidator, RequestCode,
};
/// Seconds a caller waits for a reply, and the budget for reconnect-and-resend attempts.
const REQUEST_TIMEOUT_SEC: u64 = 4;
/// log2 of the number of bytes covered by one entry of a blob's range map.
const RANGE_MAP_SHIFT: u64 = 18;
/// Low-bit mask used to round a blob size up to a whole number of range-map entries.
const RANGE_MAP_MASK: u64 = (1u64 << RANGE_MAP_SHIFT) - 1;
/// Client-side manager for blobs served by a remote blob manager over a Unix socket.
pub struct RemoteBlobMgr {
    /// Registry of blob objects opened through this manager.
    remote_blobs: Arc<RemoteBlobs>,
    /// Shared connection used for every request sent to the remote manager.
    server_connection: Arc<ServerConnection>,
    /// Directory under which per-blob range-map files are created (`<workdir>/<blob_id>`).
    workdir: String,
}
impl RemoteBlobMgr {
    /// Builds a manager that will talk to the remote blob manager at `sock`.
    ///
    /// No connection is attempted here; call [`RemoteBlobMgr::connect`] and
    /// [`RemoteBlobMgr::start`] afterwards.
    pub fn new(workdir: String, sock: &str) -> Result<Self> {
        let remote_blobs = Arc::new(RemoteBlobs::new());
        let server_connection = Arc::new(ServerConnection::new(sock, Arc::clone(&remote_blobs)));
        Ok(Self {
            remote_blobs,
            server_connection,
            workdir,
        })
    }

    /// Connects to the remote blob manager; succeeds if a connection already exists.
    pub fn connect(&self) -> Result<()> {
        self.server_connection.connect()?;
        Ok(())
    }

    /// Spawns the background thread that receives replies from the remote manager.
    pub fn start(&self) -> Result<()> {
        ServerConnection::start(Arc::clone(&self.server_connection))
    }

    /// Closes the connection and forgets every registered blob.
    pub fn shutdown(&self) {
        self.server_connection.close();
        self.remote_blobs.reset();
    }

    /// Sends a no-op request and waits for the matching reply.
    pub fn ping(&self) -> Result<()> {
        self.server_connection.call_ping()
    }

    /// Returns the blob object for `blob_info`, opening it on the remote side if needed.
    ///
    /// An already-registered blob with the same id is reused. When registration is
    /// rejected because the returned token belongs to an outdated connection
    /// generation, the blob is requested again.
    pub fn get_blob_object(&self, blob_info: &Arc<BlobInfo>) -> Result<Arc<dyn BlobObject>> {
        if let Some(cached) = self.remote_blobs.get_blob(blob_info) {
            return Ok(cached);
        }
        loop {
            let (file, base, token) = self.server_connection.call_get_blob(blob_info)?;
            let candidate = Arc::new(RemoteBlob::new(
                Arc::clone(blob_info),
                Arc::clone(&self.server_connection),
                Arc::new(file),
                base,
                token,
                &self.workdir,
            )?);
            // `add_blob` yields `None` only when `token` is stale; fetch a fresh one.
            if let Some(registered) = self.remote_blobs.add_blob(candidate, token) {
                return Ok(registered);
            }
        }
    }
}
/// Registry of opened remote blobs plus the current connection generation.
struct RemoteBlobs {
    /// Connection generation; bumped on every disconnect and compared against the
    /// upper 32 bits of blob tokens.
    generation: AtomicU32,
    /// Blobs registered since the last reset.
    active_blobs: Mutex<Vec<Arc<RemoteBlob>>>,
}
impl RemoteBlobs {
fn new() -> Self {
Self {
generation: AtomicU32::new(1),
active_blobs: Mutex::new(Vec::new()),
}
}
fn reset(&self) {
self.active_blobs.lock().unwrap().truncate(0);
}
fn add_blob(&self, blob: Arc<RemoteBlob>, token: u64) -> Option<Arc<RemoteBlob>> {
let mut guard = self.active_blobs.lock().unwrap();
for b in guard.iter() {
if blob.blob_info.blob_id() == b.blob_info.blob_id() {
return Some(b.clone());
}
}
if (token >> 32) as u32 == self.get_generation() {
guard.push(blob.clone());
return Some(blob);
}
None
}
fn get_blob(&self, blob_info: &Arc<BlobInfo>) -> Option<Arc<RemoteBlob>> {
let guard = self.active_blobs.lock().unwrap();
for blob in guard.iter() {
if blob.blob_info.blob_id() == blob_info.blob_id() {
return Some(blob.clone());
}
}
None
}
fn get_generation(&self) -> u32 {
self.generation.load(Ordering::Acquire)
}
fn notify_disconnect(&self) {
self.generation.fetch_add(1, Ordering::AcqRel);
for blob in self.active_blobs.lock().unwrap().iter() {
blob.token.store(0, Ordering::Release);
}
}
}
/// A blob opened through the remote blob manager.
struct RemoteBlob {
    /// Metadata of the blob (id, sizes).
    blob_info: Arc<BlobInfo>,
    /// Connection used to fetch missing ranges.
    conn: Arc<ServerConnection>,
    /// Tracks which ranges of the blob are already available.
    map: Arc<BlobRangeMap>,
    /// File received from the remote manager; its descriptor is exposed via `AsRawFd`.
    file: Arc<File>,
    /// Value reported by `BlobObject::base_offset`.
    base: u64,
    /// Token issued by the server; upper 32 bits hold the connection generation, and
    /// 0 means "invalidated".
    token: AtomicU64,
}
impl RemoteBlob {
fn new(
blob_info: Arc<BlobInfo>,
conn: Arc<ServerConnection>,
file: Arc<File>,
base: u64,
token: u64,
work_dir: &str,
) -> Result<Self> {
let blob_path = format!("{}/{}", work_dir, blob_info.blob_id());
let count = (blob_info.uncompressed_size() + RANGE_MAP_MASK) >> RANGE_MAP_SHIFT;
let map = BlobRangeMap::new(&blob_path, count as u32, RANGE_MAP_SHIFT as u32)?;
debug_assert!(count <= u32::MAX as u64);
Ok(RemoteBlob {
blob_info,
map: Arc::new(map),
conn,
file,
base,
token: AtomicU64::new(token),
})
}
}
impl AsRawFd for RemoteBlob {
fn as_raw_fd(&self) -> RawFd {
self.file.as_raw_fd()
}
}
impl BlobObject for RemoteBlob {
    fn base_offset(&self) -> u64 {
        self.base
    }

    fn is_all_data_ready(&self) -> bool {
        self.map.is_range_all_ready()
    }

    /// Not supported for remote blobs.
    fn fetch_range_compressed(&self, _offset: u64, _size: u64) -> Result<usize> {
        Err(enosys!())
    }

    /// Returns `Ok(0)` when the range is already marked ready; otherwise asks the remote
    /// manager to fetch it. An error from the readiness check also leads to a remote
    /// fetch.
    fn fetch_range_uncompressed(&self, offset: u64, size: u64) -> Result<usize> {
        let already_ready = matches!(self.map.is_range_ready(offset, size), Ok(true));
        if already_ready {
            Ok(0)
        } else {
            self.conn.call_fetch_range(self, offset, size)
        }
    }

    /// Not supported for remote blobs.
    fn prefetch_chunks(&self, _range: &BlobIoRange) -> Result<usize> {
        Err(enosys!())
    }
}
/// Lifecycle of an outstanding request as observed by the thread waiting on it.
#[derive(PartialEq, Eq, Debug)]
enum RequestStatus {
    /// No reply yet.
    Waiting,
    /// The connection was re-established while waiting; the request must be resent.
    Reconnect,
    /// No reply arrived within the wait timeout.
    Timeout,
    /// A reply has been stored in the request.
    Finished,
}
// NOTE(review): presumably kept to silence unused-variant/field warnings — confirm
// whether it is still needed.
#[allow(dead_code)]
/// Payload delivered to a waiting request.
enum RequestResult {
    /// Nothing stored yet.
    None,
    /// The request must be resent on the new connection.
    Reconnect,
    /// Reply to a ping.
    Noop,
    /// Reply to `GetBlob`: (result code, token, base offset, blob file on success).
    GetBlob(u32, u64, u64, Option<File>),
    /// Reply to `FetchRange`: (result code, count reported by the server).
    FetchRange(u32, u64),
}
/// An outstanding request awaiting a reply from the reply thread.
struct Request {
    /// Tag carried in the message header and used to match the reply.
    tag: u64,
    /// Signalled whenever `state` changes to a terminal status.
    condvar: Condvar,
    /// Current status together with the delivered result.
    state: Mutex<(RequestStatus, RequestResult)>,
}
impl Request {
    /// Creates a request in the `Waiting` state with no result.
    fn new(tag: u64) -> Self {
        Self {
            tag,
            state: Mutex::new((RequestStatus::Waiting, RequestResult::None)),
            condvar: Condvar::new(),
        }
    }

    /// Blocks while the request is `Waiting`.
    ///
    /// Returns as soon as the status becomes `Finished` or `Reconnect`. If a wait of
    /// `REQUEST_TIMEOUT_SEC` elapses without that happening, the status is set to
    /// `Timeout`. Each wakeup that did not time out restarts a full wait.
    fn wait_for_result(&self) {
        let timeout = Duration::from_secs(REQUEST_TIMEOUT_SEC);
        let mut state = self.state.lock().unwrap();
        while state.0 == RequestStatus::Waiting {
            let (next, wait_res) = self.condvar.wait_timeout(state, timeout).unwrap();
            state = next;
            match state.0 {
                RequestStatus::Finished | RequestStatus::Reconnect => return,
                _ if wait_res.timed_out() => state.0 = RequestStatus::Timeout,
                _ => {}
            }
        }
    }

    /// Stores `result`, marks the request `Finished` and wakes all waiters.
    ///
    /// A second reply for an already finished request is logged and ignored.
    fn set_result(&self, result: RequestResult) {
        let mut state = self.state.lock().unwrap();
        if state.0 == RequestStatus::Finished {
            debug!("received duplicated reply");
            return;
        }
        state.1 = result;
        state.0 = RequestStatus::Finished;
        self.condvar.notify_all();
    }
}
/// Connection to the remote blob manager plus the table of in-flight requests.
struct ServerConnection {
    /// Socket path passed to `Endpoint::connect`.
    sock: String,
    /// Source of request tags.
    tag: AtomicU64,
    /// Set by `close()`; afterwards `get_connection()` always fails.
    exiting: AtomicBool,
    /// Current endpoint, `None` while disconnected.
    conn: Mutex<Option<Endpoint>>,
    /// Condition variable the reply thread waits on while no endpoint is installed.
    ready: Condvar,
    /// Outstanding requests keyed by tag.
    requests: Mutex<HashMap<u64, Arc<Request>>>,
    /// Blob registry, notified whenever the connection is torn down.
    remote_blobs: Arc<RemoteBlobs>,
}
impl ServerConnection {
fn new(sock: &str, remote_blobs: Arc<RemoteBlobs>) -> Self {
ServerConnection {
sock: sock.to_owned(),
tag: AtomicU64::new(1),
exiting: AtomicBool::new(false),
conn: Mutex::new(None),
ready: Condvar::new(),
requests: Mutex::new(HashMap::new()),
remote_blobs,
}
}
fn connect(&self) -> Result<bool> {
let mut guard = self.get_connection()?;
if guard.is_some() {
return Ok(false);
}
match Endpoint::connect(&self.sock) {
Ok(v) => {
*guard = Some(v);
Ok(true)
}
Err(e) => {
error!("cannot connect to remote blob manager, {}", e);
Err(eio!())
}
}
}
fn close(&self) {
if !self.exiting.swap(true, Ordering::AcqRel) {
self.disconnect();
}
}
fn start(client: Arc<ServerConnection>) -> Result<()> {
std::thread::spawn(move || loop {
match client.get_connection() {
Ok(guard) => {
if guard.is_none() {
drop(client.ready.wait(guard));
} else {
drop(guard);
}
}
Err(_) => continue,
}
let _ = client.handle_reply();
});
Ok(())
}
fn handle_reply(&self) -> Result<()> {
let mut nr;
let mut rfd = FdSet::new();
let mut efd = FdSet::new();
loop {
{
rfd.clear();
efd.clear();
match self.get_connection()?.as_ref() {
None => return Err(eio!()),
Some(conn) => {
rfd.insert(conn.as_raw_fd());
efd.insert(conn.as_raw_fd());
nr = conn.as_raw_fd() + 1;
}
}
}
let _ = select(nr, Some(&mut rfd), None, Some(&mut efd), None)
.map_err(|e| eother!(format!("{}", e)))?;
let mut guard = self.get_connection()?;
let (hdr, files) = match guard.as_mut() {
None => return Err(eio!()),
Some(conn) => conn.recv_header().map_err(|_e| eio!())?,
};
if !hdr.is_valid() {
return Err(einval!());
}
let body_size = hdr.get_size() as usize;
match hdr.get_code() {
RequestCode::MaxCommand => return Err(eother!()),
RequestCode::Noop => self.handle_result(hdr.get_tag(), RequestResult::Noop),
RequestCode::GetBlob => {
self.handle_get_blob_reply(guard, &hdr, body_size, files)?;
}
RequestCode::FetchRange => {
self.handle_fetch_range_reply(guard, &hdr, body_size, files)?;
}
}
}
}
fn call_ping(&self) -> Result<()> {
'next_iter: loop {
let req = self.create_request();
let hdr = MsgHeader::new(
req.tag,
RequestCode::Noop,
HeaderFlag::NEED_REPLY.bits(),
0u32,
);
let msg = [0u8; 0];
self.send_msg(&hdr, &msg)?;
match self.wait_for_result(&req)? {
RequestResult::Noop => return Ok(()),
RequestResult::Reconnect => continue 'next_iter,
_ => return Err(eother!()),
}
}
}
fn call_get_blob(&self, blob_info: &Arc<BlobInfo>) -> Result<(File, u64, u64)> {
if blob_info.blob_id().len() >= 256 {
return Err(einval!("blob id is too large"));
}
'next_iter: loop {
let req = self.create_request();
let hdr = MsgHeader::new(
req.tag,
RequestCode::GetBlob,
HeaderFlag::NEED_REPLY.bits(),
std::mem::size_of::<GetBlobRequest>() as u32,
);
let generation = self.remote_blobs.get_generation();
let msg = GetBlobRequest::new(generation, blob_info.blob_id());
self.send_msg(&hdr, &msg)?;
match self.wait_for_result(&req)? {
RequestResult::GetBlob(result, token, base, file) => {
if result != 0 {
return Err(std::io::Error::from_raw_os_error(result as i32));
} else if (token >> 32) as u32 != self.remote_blobs.get_generation() {
continue 'next_iter;
} else if let Some(file) = file {
return Ok((file, base, token));
} else {
return Err(einval!());
}
}
RequestResult::Reconnect => continue 'next_iter,
_ => return Err(eother!()),
}
}
}
fn call_fetch_range(&self, blob: &RemoteBlob, start: u64, count: u64) -> Result<usize> {
'next_iter: loop {
let token = blob.token.load(Ordering::Acquire);
if (token >> 32) as u32 != self.remote_blobs.get_generation() {
self.reopen_blob(blob)?;
continue 'next_iter;
}
let req = self.create_request();
let hdr = MsgHeader::new(
req.tag,
RequestCode::FetchRange,
HeaderFlag::NEED_REPLY.bits(),
std::mem::size_of::<GetBlobRequest>() as u32,
);
let msg = FetchRangeRequest::new(token, start, count);
self.send_msg(&hdr, &msg)?;
match self.wait_for_result(&req)? {
RequestResult::FetchRange(result, size) => {
if result == FetchRangeResult::Success as u32 {
return Ok(size as usize);
} else if result == FetchRangeResult::GenerationMismatch as u32 {
continue 'next_iter;
} else {
return Err(std::io::Error::from_raw_os_error(count as i32));
}
}
RequestResult::Reconnect => continue 'next_iter,
_ => return Err(eother!()),
}
}
}
fn reopen_blob(&self, blob: &RemoteBlob) -> Result<()> {
'next_iter: loop {
let req = self.create_request();
let hdr = MsgHeader::new(
req.tag,
RequestCode::GetBlob,
HeaderFlag::NEED_REPLY.bits(),
std::mem::size_of::<GetBlobRequest>() as u32,
);
let generation = self.remote_blobs.get_generation();
let msg = GetBlobRequest::new(generation, blob.blob_info.blob_id());
self.send_msg(&hdr, &msg)?;
match self.wait_for_result(&req)? {
RequestResult::GetBlob(result, token, _base, file) => {
if result != 0 {
return Err(std::io::Error::from_raw_os_error(result as i32));
} else if (token >> 32) as u32 != self.remote_blobs.get_generation() {
continue 'next_iter;
} else if let Some(_file) = file {
blob.token.store(token, Ordering::Release);
return Ok(());
} else {
return Err(einval!());
}
}
RequestResult::Reconnect => continue 'next_iter,
_ => return Err(eother!()),
}
}
}
fn get_next_tag(&self) -> u64 {
self.tag.fetch_add(1, Ordering::AcqRel)
}
fn create_request(&self) -> Arc<Request> {
let tag = self.get_next_tag();
let request = Arc::new(Request::new(tag));
self.requests.lock().unwrap().insert(tag, request.clone());
request
}
fn get_connection(&self) -> Result<MutexGuard<Option<Endpoint>>> {
if self.exiting.load(Ordering::Relaxed) {
Err(eio!())
} else {
Ok(self.conn.lock().unwrap())
}
}
fn send_msg<T: Sized>(&self, hdr: &MsgHeader, msg: &T) -> Result<()> {
if let Ok(mut guard) = self.get_connection() {
if let Some(conn) = guard.as_mut() {
if conn.send_message(hdr, msg, None).is_ok() {
return Ok(());
}
}
}
let start = Instant::now();
self.disconnect();
loop {
self.reconnect();
if let Ok(mut guard) = self.get_connection() {
if let Some(conn) = guard.as_mut() {
if conn.send_message(hdr, msg, None).is_ok() {
return Ok(());
}
}
}
self.disconnect();
if let Some(end) = start.checked_add(Duration::from_secs(REQUEST_TIMEOUT_SEC)) {
let now = Instant::now();
if end < now {
return Err(eio!());
}
} else {
return Err(eio!());
}
std::thread::sleep(Duration::from_millis(10));
}
}
fn reconnect(&self) {
if let Ok(true) = self.connect() {
let guard = self.requests.lock().unwrap();
for entry in guard.iter() {
let mut state = entry.1.state.lock().unwrap();
if state.0 == RequestStatus::Waiting {
state.0 = RequestStatus::Reconnect;
entry.1.condvar.notify_all();
}
}
}
}
fn disconnect(&self) {
self.remote_blobs.notify_disconnect();
let mut guard = self.conn.lock().unwrap();
if let Some(conn) = guard.as_mut() {
conn.close();
}
*guard = None;
}
fn wait_for_result(&self, request: &Arc<Request>) -> Result<RequestResult> {
request.wait_for_result();
let mut guard = self.requests.lock().unwrap();
match guard.remove(&request.tag) {
None => Err(enoent!()),
Some(entry) => {
let mut guard2 = entry.state.lock().unwrap();
match guard2.0 {
RequestStatus::Waiting => panic!("should not happen"),
RequestStatus::Timeout => Err(eio!()),
RequestStatus::Reconnect => Ok(RequestResult::Reconnect),
RequestStatus::Finished => {
let mut val = RequestResult::None;
mem::swap(&mut guard2.1, &mut val);
Ok(val)
}
}
}
}
}
fn handle_result(&self, tag: u64, result: RequestResult) {
let requests = self.requests.lock().unwrap();
match requests.get(&tag) {
None => debug!("no request for tag {} found, may have timed out", tag),
Some(request) => request.set_result(result),
}
}
fn handle_get_blob_reply(
&self,
mut guard: MutexGuard<Option<Endpoint>>,
hdr: &MsgHeader,
body_size: usize,
files: Option<Vec<File>>,
) -> Result<()> {
if body_size != mem::size_of::<GetBlobReply>() {
return Err(einval!());
}
let (size, data) = match guard.as_mut() {
None => return Err(einval!()),
Some(conn) => conn.recv_data(body_size).map_err(|_e| eio!())?,
};
if size != body_size {
return Err(eio!());
}
drop(guard);
let mut msg = GetBlobReply::new(0, 0, 0);
msg.as_mut_slice().copy_from_slice(&data);
if !msg.is_valid() {
return Err(einval!());
} else if msg.result != 0 {
self.handle_result(
hdr.get_tag(),
RequestResult::GetBlob(msg.result, msg.token, msg.base, None),
);
} else {
if files.is_none() {
return Err(einval!());
}
let mut files = files.unwrap();
if files.len() != 1 {
return Err(einval!());
}
let file = files.pop().unwrap();
self.handle_result(
hdr.get_tag(),
RequestResult::GetBlob(msg.result, msg.token, msg.base, Some(file)),
);
}
Ok(())
}
fn handle_fetch_range_reply(
&self,
mut guard: MutexGuard<Option<Endpoint>>,
hdr: &MsgHeader,
body_size: usize,
files: Option<Vec<File>>,
) -> Result<()> {
if body_size != mem::size_of::<FetchRangeReply>() || files.is_some() {
return Err(einval!());
}
let (size, data) = match guard.as_mut() {
None => return Err(einval!()),
Some(conn) => conn.recv_data(body_size).map_err(|_e| eio!())?,
};
if size != body_size {
return Err(eio!());
}
drop(guard);
let mut msg = FetchRangeReply::new(0, 0, 0);
msg.as_mut_slice().copy_from_slice(&data);
if !msg.is_valid() {
return Err(einval!());
} else {
self.handle_result(
hdr.get_tag(),
RequestResult::FetchRange(msg.result, msg.count),
);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the `Request` state machine: timeout, reconnect, then a delivered result.
    #[test]
    fn test_request() {
        let req = Arc::new(Request::new(1));
        let req1 = req.clone();
        assert_eq!(req.tag, 1);
        {
            let guard = req.state.lock().unwrap();
            assert_eq!(guard.0, RequestStatus::Waiting);
            // `matches!` only yields a bool; it must be asserted to check anything.
            assert!(matches!(guard.1, RequestResult::None));
        }
        let (sender, receiver) = std::sync::mpsc::channel::<bool>();
        let worker = std::thread::spawn(move || {
            let _ = receiver.recv().unwrap();
            {
                let mut guard = req1.state.lock().unwrap();
                guard.0 = RequestStatus::Reconnect;
            }
            let _ = receiver.recv().unwrap();
            req1.set_result(RequestResult::Reconnect);
        });
        {
            req.wait_for_result();
            let mut guard = req.state.lock().unwrap();
            assert_eq!(guard.0, RequestStatus::Timeout);
            guard.0 = RequestStatus::Waiting;
        }
        sender.send(true).unwrap();
        {
            req.wait_for_result();
            let mut guard = req.state.lock().unwrap();
            assert_eq!(guard.0, RequestStatus::Reconnect);
            guard.0 = RequestStatus::Waiting;
        }
        sender.send(true).unwrap();
        {
            req.wait_for_result();
            let guard = req.state.lock().unwrap();
            assert_eq!(guard.0, RequestStatus::Finished);
            assert!(matches!(guard.1, RequestResult::Reconnect));
        }
        // Surface any panic from the helper thread.
        worker.join().unwrap();
    }
}