use std::error::Error;
use std::fs::File;
use std::io;
use std::io::{Read, Seek, SeekFrom};
use std::ops::Range;
use std::path::Path;
use std::result::Result;
use reqwest::blocking::{Client, Response};
/// The concrete backend a [`NetFile`] reads from.
#[derive(Debug)]
enum NetFileStream {
    /// A file opened on the local filesystem.
    File(File),
    /// A remote resource read through HTTP range requests.
    Http(HttpSeekableReader),
}
/// A readable, seekable handle over either a local file or an HTTP(S) URL.
///
/// The `Read` and `Seek` implementations forward to whichever backend is
/// stored in `stream`.
#[derive(Debug)]
pub struct NetFile {
    /// The backend every read and seek is delegated to.
    stream: NetFileStream,
}
impl NetFile {
    /// Wraps an already-constructed backend.
    fn new(stream: NetFileStream) -> Self {
        NetFile { stream }
    }

    /// Opens `filename` as a local file.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `NotFound` when the path is not an
    /// existing regular file, or the error from `File::open` otherwise.
    fn open_file(filename: &str) -> Result<NetFile, Box<dyn Error>> {
        let path = Path::new(filename);
        // `is_file` is already false for a missing path, so no separate
        // `exists` probe is needed.
        if !path.is_file() {
            return Err(Box::new(io::Error::new(
                io::ErrorKind::NotFound,
                "File not found",
            )));
        }
        let file = File::open(path)?;
        Ok(NetFile::new(NetFileStream::File(file)))
    }

    /// Opens `url` as a remote file.
    ///
    /// A `HEAD` request is sent first and its `Content-Length` header is used
    /// as the size of the resource.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest` error when the request cannot be sent, an
    /// `io::Error` of kind `InvalidInput` when the response status is not a
    /// success, and an `io::Error` of kind `InvalidData` when the
    /// `Content-Length` header is missing or cannot be parsed as `u64`.
    fn open_http(url: &str) -> Result<NetFile, Box<dyn Error>> {
        let client = Client::new();
        let head_resp = client.head(url).send()?;
        if !head_resp.status().is_success() {
            return Err(Box::new(io::Error::new(
                io::ErrorKind::InvalidInput,
                "HTTP request failed",
            )));
        }

        // Produce plain `io::Error` values and let `?` box them exactly once,
        // so callers can `downcast_ref::<io::Error>()` on every error path.
        let bad_header = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidData, msg);

        let content_length = head_resp
            .headers()
            .get("Content-Length")
            .ok_or_else(|| bad_header("Missing Content-Length header"))?
            .to_str()
            .map_err(|_| bad_header("Invalid Content-Length header"))?
            .parse::<u64>()
            .map_err(|_| bad_header("Invalid Content-Length header"))?;

        let http_reader = HttpSeekableReader::new(client, url.to_string(), content_length);
        Ok(NetFile::new(NetFileStream::Http(http_reader)))
    }

    /// Opens `filename`, treating it as a URL when it starts with `http://`
    /// or `https://` and as a local path otherwise.
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever opener was selected.
    pub fn open(filename: &str) -> Result<NetFile, Box<dyn Error>> {
        if filename.starts_with("http://") || filename.starts_with("https://") {
            NetFile::open_http(filename)
        } else {
            NetFile::open_file(filename)
        }
    }
}
impl Read for NetFile {
    /// Reads into `buf` from the underlying backend and returns the number of
    /// bytes that backend reported.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Select the backend once, then forward the call.
        let backend: &mut dyn Read = match &mut self.stream {
            NetFileStream::File(local) => local,
            NetFileStream::Http(remote) => remote,
        };
        backend.read(buf)
    }
}
impl Seek for NetFile {
    /// Repositions the underlying backend and returns the position it reports.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Select the backend once, then forward the call.
        let backend: &mut dyn Seek = match &mut self.stream {
            NetFileStream::File(local) => local,
            NetFileStream::Http(remote) => remote,
        };
        backend.seek(pos)
    }
}
/// Maximum number of bytes requested from the server per range request (8 KiB).
const BUFFER_SIZE: usize = 8 * 1024;
/// A reader over a remote resource that keeps one fetched chunk in memory.
#[derive(Debug)]
struct HttpSeekableReader {
    /// HTTP client used for every request.
    client: Client,
    /// URL of the remote resource.
    url: String,
    /// Absolute offset of the next byte to be read.
    current_pos: u64,
    /// Total size of the resource, as supplied to `new`.
    content_length: u64,
    /// Bytes of the most recently fetched chunk.
    buffer: Vec<u8>,
    /// Absolute offset of the first byte held in `buffer`.
    buffer_start: u64,
    /// Absolute offset one past the last byte held in `buffer`.
    buffer_end: u64,
}
impl HttpSeekableReader {
fn new(client: Client, url: String, content_length: u64) -> Self {
HttpSeekableReader {
client,
url,
current_pos: 0,
content_length,
buffer: Vec::new(),
buffer_start: 0,
buffer_end: 0,
}
}
fn get_range(&self, range: Range<u64>) -> Result<Response, reqwest::Error> {
let range_header = format!("bytes={}-{}", range.start, range.end - 1);
self.client
.get(&self.url)
.header("Range", range_header)
.send()
}
fn fill_buffer(&mut self) -> io::Result<()> {
let range_end = (self.current_pos + BUFFER_SIZE as u64).min(self.content_length);
let response = self
.get_range(self.current_pos..range_end)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let bytes = response
.bytes()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.buffer = bytes.to_vec(); self.buffer_start = self.current_pos;
self.buffer_end = self.buffer_start + self.buffer.len() as u64;
Ok(())
}
}
impl Read for HttpSeekableReader {
    /// Copies bytes starting at `current_pos` into `buf` and advances the
    /// position by the number of bytes copied.
    ///
    /// Returns `Ok(0)` without touching the network when `buf` is empty or
    /// the position is at or past `content_length`. A single call copies at
    /// most what the current buffer holds, so it may return fewer bytes than
    /// `buf.len()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from refilling the buffer.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // End of file (or nothing to fill): report it instead of issuing a
        // request for an empty byte range.
        if buf.is_empty() || self.current_pos >= self.content_length {
            return Ok(0);
        }

        let cached = self.buffer_start <= self.current_pos && self.current_pos < self.buffer_end;
        if !cached {
            self.fill_buffer()?;
        }

        // Everything in `buffer` from this offset onward is unread data.
        let offset = (self.current_pos - self.buffer_start) as usize;
        let unread = self.buffer.get(offset..).unwrap_or(&[]);
        let count = unread.len().min(buf.len());
        buf[..count].copy_from_slice(&unread[..count]);
        self.current_pos += count as u64;
        Ok(count)
    }
}
impl Seek for HttpSeekableReader {
    /// Moves the read position and returns the new absolute offset.
    ///
    /// Only `current_pos` changes; no request is made and the buffer is left
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `InvalidInput` when the target would be
    /// before byte 0, would overflow `u64`, or lies beyond `content_length`.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        /// Applies a signed `delta` to `base`; `None` on underflow/overflow.
        fn offset_from(base: u64, delta: i64) -> Option<u64> {
            if delta >= 0 {
                base.checked_add(delta as u64)
            } else {
                // `unsigned_abs` is well-defined even for `i64::MIN`, unlike
                // negating and casting.
                base.checked_sub(delta.unsigned_abs())
            }
        }

        let target = match pos {
            SeekFrom::Start(p) => Some(p),
            SeekFrom::Current(delta) => offset_from(self.current_pos, delta),
            SeekFrom::End(delta) => offset_from(self.content_length, delta),
        };

        // The `Seek` contract treats a negative target as an error rather
        // than clamping it to the start.
        let new_pos = target.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;

        if new_pos > self.content_length {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Seek position beyond file size",
            ));
        }
        self.current_pos = new_pos;
        Ok(new_pos)
    }
}