use bytes::Bytes;
use chrono::Local;
use clap::Parser;
use http_body_util::{BodyExt, Full};
use hyper::server::conn::http1;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls::ServerConfig;
use rustls_pemfile::{certs, pkcs8_private_keys};
use serde::Serialize;
use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
mod db_logger;
#[cfg(test)]
mod db_logger_tests;
mod logger;
#[cfg(test)]
mod logger_tests;
mod size_parser;
mod time_parser;
use db_logger::DbLogger;
use logger::Logger;
use time_parser::parse_time_string;
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
// CLI / environment configuration (clap derive). Every flag can also be
// supplied via an environment variable with the PROXY_OLLAMA_ prefix.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// clap fields become --help text and would change the CLI's output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    // Port the proxy listens on.
    // NOTE(review): the default (11434) is the same port Ollama itself uses;
    // running proxy and upstream on one host needs one of them moved — confirm.
    #[arg(short, long, default_value_t = 11434, env = "PROXY_OLLAMA_PORT")]
    port: u16,
    // Upstream Ollama server address.
    #[arg(
        short,
        long,
        default_value = "http://localhost:11434",
        env = "PROXY_OLLAMA_URL"
    )]
    ollama_url: String,
    // Optional log file; when absent, logs go to the console only.
    #[arg(short, long, env = "PROXY_OLLAMA_LOG_FILE")]
    log_file: Option<PathBuf>,
    // Bearer token required for model-management endpoints; None disables auth.
    #[arg(short, long, env = "PROXY_OLLAMA_API_KEY")]
    api_key: Option<String>,
    // Comma-separated list of client IPs allowed to connect.
    #[arg(long, env = "PROXY_OLLAMA_ALLOWED_IPS")]
    allowed_ips: Option<String>,
    // Database connection string (only exists with the database-logging feature).
    #[cfg(feature = "database-logging")]
    #[arg(long, env = "PROXY_OLLAMA_DB_URL")]
    db_url: Option<String>,
    // Serve TLS instead of plain HTTP (requires cert_file and key_file).
    #[arg(long, env = "PROXY_OLLAMA_HTTPS")]
    https: bool,
    // PEM certificate chain for HTTPS mode.
    #[arg(long, env = "PROXY_OLLAMA_CERT_FILE")]
    cert_file: Option<PathBuf>,
    // PEM private key for HTTPS mode.
    #[arg(long, env = "PROXY_OLLAMA_KEY_FILE")]
    key_file: Option<PathBuf>,
    // Bind address for the listener.
    #[arg(long, default_value = "127.0.0.1", env = "PROXY_OLLAMA_HOST")]
    host: String,
    // Log-rotation threshold, e.g. "10MB" — presumably parsed by the
    // size_parser module inside Logger; confirm.
    #[arg(long, default_value = "10MB", env = "PROXY_OLLAMA_LOG_ROTATE_SIZE")]
    log_rotate_size: String,
    // Rotated files to keep; 0 means unlimited.
    #[arg(long, default_value_t = 0, env = "PROXY_OLLAMA_MAX_LOG_FILES")]
    max_log_files: u32,
    // Minimum keep_alive (e.g. "5m") forced onto generate requests.
    #[arg(long, env = "PROXY_OLLAMA_MIN_KEEP_ALIVE")]
    min_keep_alive: Option<String>,
}
/// Wrap a chunk of bytes in a boxed, infallible HTTP body.
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody {
    let body = Full::new(chunk.into());
    // `Full`'s error type is uninhabited; the empty match converts it into
    // the `hyper::Error` slot of `BoxBody` without ever running.
    body.map_err(|never| match never {}).boxed()
}
/// Serialize `data` to JSON and wrap it in a response with `status`.
/// Falls back to a plain-text 500 if serialization fails.
fn json_response<T: Serialize>(data: &T, status: StatusCode) -> Response<BoxBody> {
    let Ok(json) = serde_json::to_string(data) else {
        return Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(full("Error serializing response"))
            .unwrap();
    };
    Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .body(full(json))
        .unwrap()
}
/// Shared, immutable proxy configuration handed to every connection task.
struct OllamaConfig {
    // Upstream Ollama address as given on the CLI (host:port or full URL).
    base_url: String,
    logger: Arc<Logger>,
    // Optional bearer token; `None` disables authentication entirely.
    api_key: Option<String>,
    // Parsed IP allowlist; `None` means every client IP is accepted.
    allowed_ips: Option<Vec<std::net::IpAddr>>,
    // Minimum keep_alive (seconds) forced onto generate requests, if set.
    min_keep_alive_seconds: Option<i64>,
    db_logger: Arc<DbLogger>,
}
impl OllamaConfig {
fn new(
base_url: String,
logger: Arc<Logger>,
api_key: Option<String>,
allowed_ips: Option<String>,
min_keep_alive: Option<String>,
db_logger: Arc<DbLogger>,
) -> Self {
let allowed_ips = allowed_ips.map(|ips_str| {
ips_str
.split(',')
.filter_map(|ip| ip.trim().parse::<std::net::IpAddr>().ok())
.collect::<Vec<_>>()
});
let min_keep_alive_seconds =
min_keep_alive.and_then(|time_str| match parse_time_string(&time_str) {
Ok(seconds) => {
Some(seconds)
}
Err(e) => {
eprintln!("Warning: Could not parse min_keep_alive time: {e}");
None
}
});
Self {
base_url,
logger,
api_key,
allowed_ips,
min_keep_alive_seconds,
db_logger,
}
}
fn build_uri(&self, path: &str) -> Result<hyper::Uri, hyper::http::uri::InvalidUri> {
let uri_str = format!("http://{}{}", self.base_url, path);
uri_str
.parse::<hyper::Uri>()
.inspect_err(|_e| println!("Parse \"{path}\" fails"))
}
fn is_ip_allowed(&self, client_ip: &SocketAddr) -> bool {
if self.allowed_ips.is_none() {
return true;
}
if let Some(ref allowed_ips) = self.allowed_ips {
allowed_ips.contains(&client_ip.ip())
} else {
true
}
}
}
/// Build a plain-text error response carrying `message` with `status`.
fn create_error_response(status: StatusCode, message: String) -> Response<BoxBody> {
    let body = full(message);
    Response::builder().status(status).body(body).unwrap()
}
/// Thin wrapper around `OllamaConfig::build_uri`, kept for call-site symmetry.
fn build_ollama_uri(
    path: &str,
    ollama_config: &Arc<OllamaConfig>,
) -> Result<hyper::Uri, hyper::http::uri::InvalidUri> {
    ollama_config.build_uri(path)
}
/// Copy every header except `Authorization` from `source` into `target`.
///
/// Fix: use `append` rather than `insert` so repeated header names (for
/// example multiple `Set-Cookie` lines in a response) are all forwarded;
/// `insert` kept only the last occurrence of each name.
fn copy_headers(
    source: &hyper::HeaderMap<hyper::header::HeaderValue>,
    target: &mut hyper::HeaderMap<hyper::header::HeaderValue>,
) {
    for (name, value) in source {
        // Never forward the client's credentials upstream.
        if name != hyper::header::AUTHORIZATION {
            target.append(name.clone(), value.clone());
        }
    }
}
/// Forward `req` to the upstream Ollama server and relay its response.
///
/// Both the request and response bodies are fully buffered so they can be
/// written to the detailed log and the database logger; streaming upstream
/// responses are therefore only delivered to the client once complete.
///
/// Fix: response-side logging previously hard-coded `Method::GET`; the
/// actual request method is now recorded for responses as well.
async fn proxy_to_ollama<B>(
    req: Request<B>,
    path: &str,
    ollama_config: &Arc<OllamaConfig>,
    client_ip: &SocketAddr,
) -> Result<Response<BoxBody>, hyper::Error>
where
    B: hyper::body::Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + std::fmt::Debug,
{
    ollama_config
        .logger
        .log(&format!(
            "Proxying request to Ollama: {} {}",
            req.method(),
            path
        ))
        .await;
    let (parts, body) = req.into_parts();
    // Keep a copy of the method: `parts.method` is moved into the outgoing
    // request below, but it is still needed for response logging.
    let method = parts.method.clone();
    // Buffer the request body; on failure we still forward an empty body.
    let maybe_body_bytes = match body.collect().await {
        Ok(collected) => {
            let bytes = collected.to_bytes();
            log_detailed_json(
                &ollama_config.logger,
                "request",
                &parts.method,
                path,
                None,
                &bytes,
                client_ip,
                &parts.headers,
            )
            .await;
            ollama_config
                .db_logger
                .log_request(client_ip, &parts.method, path, &parts.headers, &bytes)
                .await;
            Some(bytes)
        }
        Err(e) => {
            let err_msg = format!("Error collecting request body: {e:?}");
            ollama_config.logger.log(&err_msg).await;
            None
        }
    };
    let uri = match build_ollama_uri(path, ollama_config) {
        Ok(uri) => uri,
        Err(e) => {
            let err_msg = format!("Error parsing URI for path {path}: {e}");
            ollama_config.logger.log(&err_msg).await;
            return Ok(create_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Invalid URI".to_string(),
            ));
        }
    };
    // NOTE(review): a fresh client (and connection pool) is built per call;
    // sharing one client across requests would enable connection reuse.
    let client = Client::builder(TokioExecutor::new()).build_http();
    let mut request_builder = Request::builder().uri(uri).method(parts.method);
    if let Some(headers) = request_builder.headers_mut() {
        copy_headers(&parts.headers, headers);
    }
    let forwarded_req = if let Some(body_bytes) = maybe_body_bytes {
        request_builder.body(Full::new(body_bytes).boxed()).unwrap()
    } else {
        request_builder
            .body(Full::new(Bytes::new()).boxed())
            .unwrap()
    };
    match client.request(forwarded_req).await {
        Ok(ollama_resp) => {
            let (parts, body) = ollama_resp.into_parts();
            let status = parts.status;
            match body.collect().await {
                Ok(collected) => {
                    let bytes = collected.to_bytes();
                    log_detailed_json(
                        &ollama_config.logger,
                        "response",
                        &method,
                        path,
                        Some(status),
                        &bytes,
                        client_ip,
                        &parts.headers,
                    )
                    .await;
                    ollama_config
                        .db_logger
                        .log_response(client_ip, &method, path, &status, &parts.headers, &bytes)
                        .await;
                    // Relay status, headers and the buffered body downstream.
                    let mut builder = Response::builder().status(status);
                    if let Some(headers) = builder.headers_mut() {
                        copy_headers(&parts.headers, headers);
                    }
                    Ok(builder.body(full(bytes)).unwrap())
                }
                Err(e) => {
                    let err_msg = format!("Error collecting Ollama response body: {e}");
                    ollama_config.logger.log(&err_msg).await;
                    Ok(create_error_response(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Error collecting response from Ollama".to_string(),
                    ))
                }
            }
        }
        Err(e) => {
            let err_msg = format!("Error forwarding request to Ollama: {e}");
            ollama_config.logger.log(&err_msg).await;
            Ok(create_error_response(
                StatusCode::BAD_GATEWAY,
                format!("Error forwarding request to Ollama: {e}"),
            ))
        }
    }
}
/// Write a one-line access-log entry for an incoming request.
async fn log_request(logger: &Logger, method: &Method, path: &str, client_ip: &SocketAddr) {
    let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S%.3f").to_string();
    let line = format!("[{timestamp}] {client_ip} {method} {path}");
    logger.log(&line).await;
}
/// Write a one-line access-log entry for an outgoing response, including
/// the numeric status code.
async fn log_response(
    logger: &Logger,
    method: &Method,
    path: &str,
    status: &StatusCode,
    client_ip: &SocketAddr,
) {
    let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S%.3f").to_string();
    let code = status.as_u16();
    let line = format!("[{timestamp}] {client_ip} {method} {path} - {code}");
    logger.log(&line).await;
}
fn handle_not_found() -> Response<BoxBody> {
Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "application/json")
.body(full(r#"{"error":"Not Found"}"#))
.unwrap()
}
/// Check the request's `Authorization: Bearer <key>` header against the
/// configured API key. Requests are always allowed when no key is set.
///
/// Improvement: compares borrowed strings directly instead of allocating a
/// `String` per check, and matches the combinator style used by
/// `is_unload_request_authenticated`. NOTE(review): this remains a
/// non-constant-time comparison; consider a constant-time compare if
/// timing side channels matter for your threat model.
fn is_authenticated(
    req: &Request<hyper::body::Incoming>,
    ollama_config: &Arc<OllamaConfig>,
) -> bool {
    let Some(expected) = ollama_config.api_key.as_deref() else {
        return true;
    };
    req.headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.strip_prefix("Bearer "))
        .is_some_and(|key| key == expected)
}
/// 401 response instructing the client to supply a bearer token.
fn handle_unauthorized() -> Response<BoxBody> {
    const BODY: &str = r#"{"error":"Unauthorized - API key required"}"#;
    Response::builder()
        .status(StatusCode::UNAUTHORIZED)
        .header("Content-Type", "application/json")
        .header("WWW-Authenticate", "Bearer")
        .body(full(BODY))
        .unwrap()
}
/// A request asks Ollama to unload a model when keep_alive is zero
/// (top-level or inside "options") and the prompt is absent or empty.
fn is_unload_model_request(json: &serde_json::Value) -> bool {
    // Zero may be the number 0 or one of the literal strings "0"/"0s"/"0m"/"0h".
    fn is_zero_value(value: &serde_json::Value) -> bool {
        match value {
            serde_json::Value::Number(n) => n.as_i64() == Some(0),
            serde_json::Value::String(s) => matches!(s.as_str(), "0" | "0s" | "0m" | "0h"),
            _ => false,
        }
    }
    let top_level_zero = json.get("keep_alive").is_some_and(is_zero_value);
    let options_zero = json
        .get("options")
        .and_then(|o| o.as_object())
        .and_then(|o| o.get("keep_alive"))
        .is_some_and(is_zero_value);
    if !(top_level_zero || options_zero) {
        return false;
    }
    // No prompt at all counts as empty; a non-string prompt does not.
    json.get("prompt")
        .map_or(true, |prompt| prompt.as_str().is_some_and(str::is_empty))
}
/// Log which model a generate-style request targets, distinguishing unload
/// requests from ordinary generation. Silent when no "model" field exists.
async fn log_model_info(json: &serde_json::Value, logger: &Logger, client_ip: &SocketAddr) {
    let Some(model) = json.get("model").and_then(|m| m.as_str()) else {
        return;
    };
    let message = if is_unload_model_request(json) {
        format!("Unloading model from memory from {client_ip}: {model}")
    } else {
        format!("Forwarding generate request to Ollama from {client_ip} for model: {model}")
    };
    logger.log(&message).await;
}
/// Canned /api/generate reply served when the upstream server is unreachable.
fn create_generate_fallback_response() -> Response<BoxBody> {
    let payload = serde_json::json!({
        "model": "unknown",
        "created_at": Local::now().to_rfc3339(),
        "response": "Mock response (Ollama server unavailable)",
        "done": true
    });
    json_response(&payload, StatusCode::OK)
}
/// Like `is_authenticated`, but works from a bare header map (used once the
/// request itself has already been consumed for body inspection).
fn is_unload_request_authenticated(
    headers: &hyper::HeaderMap<hyper::header::HeaderValue>,
    ollama_config: &OllamaConfig,
) -> bool {
    match ollama_config.api_key {
        // No key configured: unloading is open to everyone.
        None => true,
        Some(ref api_key) => headers
            .get("Authorization")
            .and_then(|h| h.to_str().ok())
            .and_then(|s| s.strip_prefix("Bearer "))
            .is_some_and(|key| key == api_key),
    }
}
/// Raise too-small positive `keep_alive` values (top-level and inside
/// "options") up to the configured minimum. Returns the re-serialized body
/// when anything changed, `None` otherwise.
///
/// Improvement: the duplicated "is this below the minimum" predicate is
/// extracted into a single helper, and the `Option` nesting is flattened
/// with `?` / early returns. Behavior is unchanged: zero (unload),
/// negative (infinite) and unparsable values are never modified.
async fn apply_min_keep_alive_and_log(
    json: &serde_json::Value,
    ollama_config: &Arc<OllamaConfig>,
    client_ip: &SocketAddr,
) -> Option<Bytes> {
    // True when the value is a positive keep_alive below `min_seconds`.
    fn below_min(value: &serde_json::Value, min_seconds: i64) -> bool {
        match value {
            serde_json::Value::Number(n) => n
                .as_i64()
                .is_some_and(|current| current > 0 && current < min_seconds),
            serde_json::Value::String(s) => {
                parse_time_string(s).is_ok_and(|current| current > 0 && current < min_seconds)
            }
            _ => false,
        }
    }
    let min_seconds = ollama_config.min_keep_alive_seconds?;
    if min_seconds <= 0 {
        // Non-positive minimum means the feature is effectively off.
        return None;
    }
    let mut modified_json = json.clone();
    let mut was_modified = false;
    // Top-level keep_alive.
    if modified_json
        .get("keep_alive")
        .is_some_and(|v| below_min(v, min_seconds))
    {
        if let Some(obj) = modified_json.as_object_mut() {
            obj.insert(
                "keep_alive".to_string(),
                serde_json::Value::Number(min_seconds.into()),
            );
            was_modified = true;
            ollama_config
                .logger
                .log(&format!(
                    "Applied minimum keep_alive of {min_seconds}s to request from {client_ip}"
                ))
                .await;
        }
    }
    // keep_alive nested inside the "options" object.
    if let Some(obj) = modified_json
        .get_mut("options")
        .and_then(|options| options.as_object_mut())
    {
        if obj
            .get("keep_alive")
            .is_some_and(|v| below_min(v, min_seconds))
        {
            obj.insert(
                "keep_alive".to_string(),
                serde_json::Value::Number(min_seconds.into()),
            );
            was_modified = true;
            ollama_config
                .logger
                .log(&format!(
                    "Applied minimum keep_alive of {min_seconds}s to options in request from {client_ip}"
                ))
                .await;
        }
    }
    if was_modified {
        serde_json::to_vec(&modified_json).ok().map(Bytes::from)
    } else {
        None
    }
}
/// Assemble a POST request carrying `body` addressed to the upstream URI.
fn build_ollama_generate_request(uri: hyper::Uri, body: Bytes) -> Request<BoxBody> {
    let full_body = Full::new(body).map_err(|never| match never {}).boxed();
    Request::builder()
        .method(Method::POST)
        .uri(uri)
        .body(full_body)
        .expect("Failed to create request")
}
async fn handle_generate_with_model_info(
req: Request<hyper::body::Incoming>,
ollama_config: &Arc<OllamaConfig>,
path: &str,
client_ip: &SocketAddr,
) -> Response<BoxBody> {
let headers = req.headers().clone();
let (_parts, body) = req.into_parts();
let maybe_body_bytes = match body.collect().await {
Ok(collected) => Some(collected.to_bytes()),
Err(e) => {
let err_msg = format!("Error collecting request body for logging: {e}");
ollama_config.logger.log(&err_msg).await;
None
}
};
let mut modified_body = None;
if let Some(ref body_bytes) = maybe_body_bytes {
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(body_bytes) {
if is_unload_model_request(&json) {
let is_auth = is_unload_request_authenticated(&headers, ollama_config);
if !is_auth {
ollama_config
.logger
.log(&format!(
"Unauthorized attempt to unload model from {client_ip}"
))
.await;
return handle_unauthorized();
}
if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
ollama_config
.logger
.log(&format!(
"Unloading model from memory (authenticated) from {client_ip}: {model}"
))
.await;
}
} else {
log_model_info(&json, &ollama_config.logger, client_ip).await;
modified_body = apply_min_keep_alive_and_log(&json, ollama_config, client_ip).await;
}
}
}
let uri = ollama_config
.build_uri(path)
.expect("Failed to build URI for generate endpoint");
let req = build_ollama_generate_request(
uri,
if let Some(modified) = modified_body {
modified
} else if let Some(bytes) = maybe_body_bytes {
bytes
} else {
Bytes::new()
},
);
if let Ok(response) = proxy_to_ollama(req, path, ollama_config, client_ip).await {
response
} else {
ollama_config
.logger
.log(&format!(
"Failed to get response from Ollama for {client_ip}, using mock response"
))
.await;
create_generate_fallback_response()
}
}
/// Proxy a body-less request to Ollama, producing `fallback_generator()`'s
/// response when the upstream call fails.
///
/// Fix: URI construction errors used to `expect()` and panic the
/// connection task; the fallback response is now served instead.
async fn handle_ollama_endpoint_with_fallback<F>(
    method: Method,
    path: &str,
    ollama_config: &Arc<OllamaConfig>,
    log_message: &str,
    client_ip: &SocketAddr,
    fallback_generator: F,
) -> Response<BoxBody>
where
    F: FnOnce() -> Response<BoxBody>,
{
    ollama_config.logger.log(log_message).await;
    let uri = match ollama_config.build_uri(path) {
        Ok(uri) => uri,
        Err(e) => {
            ollama_config
                .logger
                .log(&format!("Failed to build URI for Ollama endpoint: {e}"))
                .await;
            return fallback_generator();
        }
    };
    // With a valid URI and method, the builder cannot fail.
    let req = Request::builder()
        .method(method)
        .uri(uri)
        .body(Full::new(Bytes::new()).boxed())
        .expect("Failed to create request");
    if let Ok(response) = proxy_to_ollama(req, path, ollama_config, client_ip).await {
        response
    } else {
        ollama_config
            .logger
            .log("Failed to get response from Ollama, using fallback response")
            .await;
        fallback_generator()
    }
}
/// Map a proxy-level hyper error onto a plain 500 response.
fn handle_proxy_error(e: &hyper::Error) -> Response<BoxBody> {
    let message = format!("Error: {e}");
    Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .body(full(message))
        .unwrap()
}
/// GET /api/tags: forward to Ollama, serving a static two-model mock
/// listing when the upstream server is unavailable.
async fn handle_models_endpoint(
    ollama_config: &Arc<OllamaConfig>,
    client_ip: &SocketAddr,
) -> Response<BoxBody> {
    // Mock payload mirroring Ollama's /api/tags response shape.
    let fallback = || {
        let models = serde_json::json!({
            "models": [
                {
                    "name": "llama2",
                    "modified_at": "2023-08-02T17:02:23Z",
                    "size": 3_791_730_298_u64,
                    "digest": "sha256:a2...",
                    "details": {
                        "format": "gguf",
                        "family": "llama",
                        "parameter_size": "7B",
                        "quantization_level": "Q4_0",
                    },
                },
                {
                    "name": "mistral",
                    "modified_at": "2023-11-20T12:15:30Z",
                    "size": 4_356_823_129_u64,
                    "digest": "sha256:b1...",
                    "details": {
                        "format": "gguf",
                        "family": "mistral",
                        "parameter_size": "7B",
                        "quantization_level": "Q5_K",
                    },
                },
            ],
        });
        json_response(&models, StatusCode::OK)
    };
    let log_message = format!("Forwarding request to list models to Ollama from {client_ip}");
    handle_ollama_endpoint_with_fallback(
        Method::GET,
        "/api/tags",
        ollama_config,
        &log_message,
        client_ip,
        fallback,
    )
    .await
}
/// Auth-gated proxy for model management operations
/// (create/copy/pull/push/delete).
async fn handle_model_management_endpoint(
    req: Request<hyper::body::Incoming>,
    ollama_config: &Arc<OllamaConfig>,
    path: &str,
    operation: &str,
    client_ip: &SocketAddr,
) -> Response<BoxBody> {
    // Management endpoints always require the API key when one is configured.
    if !is_authenticated(&req, ollama_config) {
        ollama_config
            .logger
            .log(&format!(
                "Unauthorized request from {client_ip} for operation: {operation}"
            ))
            .await;
        return handle_unauthorized();
    }
    let operation_description = format!("model {operation}");
    ollama_config
        .logger
        .log(&format!(
            "Forwarding {operation_description} request to Ollama from {client_ip}"
        ))
        .await;
    match proxy_to_ollama(req, path, ollama_config, client_ip).await {
        Ok(response) => response,
        Err(e) => handle_proxy_error(&e),
    }
}
/// Default passthrough: log the request and proxy it unchanged.
async fn forward_to_ollama(
    req: Request<hyper::body::Incoming>,
    ollama_config: &Arc<OllamaConfig>,
    path: &str,
    client_ip: &SocketAddr,
) -> Response<BoxBody> {
    let message = format!("Forwarding request to Ollama from {client_ip}: {path}");
    ollama_config.logger.log(&message).await;
    match proxy_to_ollama(req, path, ollama_config, client_ip).await {
        Ok(response) => response,
        Err(e) => handle_proxy_error(&e),
    }
}
/// Route /api/* requests to the appropriate specialized handler.
async fn handle_api_endpoint(
    req: Request<hyper::body::Incoming>,
    ollama_config: &Arc<OllamaConfig>,
    path: &str,
    client_ip: &SocketAddr,
) -> Response<BoxBody> {
    let method = req.method().clone();
    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    match (method, segments.as_slice()) {
        // Inference endpoints get model-info logging / keep_alive handling.
        (Method::POST, ["api", "generate" | "chat" | "embed" | "embeddings"]) => {
            handle_generate_with_model_info(req, ollama_config, path, client_ip).await
        }
        // Model listing has a mock fallback when Ollama is down.
        (Method::GET, ["api", "tags"]) => handle_models_endpoint(ollama_config, client_ip).await,
        // Management operations require authentication.
        (Method::POST, ["api", op @ ("create" | "copy" | "pull" | "push")])
        | (Method::DELETE, ["api", op @ "delete"]) => {
            handle_model_management_endpoint(req, ollama_config, path, op, client_ip).await
        }
        // Anything else under /api/ is passed through untouched.
        _ => forward_to_ollama(req, ollama_config, path, client_ip).await,
    }
}
/// 403 response for clients rejected by the IP allowlist.
fn handle_ip_blocked(client_ip: &SocketAddr) -> Response<BoxBody> {
    let body = format!(r#"{{"error":"Forbidden - IP address {client_ip} is not allowed"}}"#);
    Response::builder()
        .status(StatusCode::FORBIDDEN)
        .header("Content-Type", "application/json")
        .body(full(body))
        .unwrap()
}
/// Resolve the "effective" client address: the first entry of an
/// `X-Forwarded-For` header when present and parseable, otherwise the TCP
/// peer address. The port is always taken from the real socket.
///
/// NOTE(review): the header is trusted unconditionally and the result is
/// what `is_ip_allowed` checks — so a directly-connected client can spoof
/// an allowlisted IP simply by sending X-Forwarded-For. Confirm this
/// service is only reachable through a trusted reverse proxy, or restrict
/// header trust to known proxy peer addresses.
fn get_client_ip(req: &Request<hyper::body::Incoming>, socket_addr: &SocketAddr) -> SocketAddr {
    if let Some(forwarded_for) = req.headers().get("X-Forwarded-For") {
        if let Ok(forwarded_str) = forwarded_for.to_str() {
            // Left-most entry is the original client in an XFF chain.
            if let Some(ip_str) = forwarded_str.split(',').next() {
                if let Ok(ip) = ip_str.trim().parse::<std::net::IpAddr>() {
                    return SocketAddr::new(ip, socket_addr.port());
                }
            }
        }
    }
    // No usable header: fall back to the TCP peer address.
    *socket_addr
}
/// Top-level per-request entry point: access logging, allowlist
/// enforcement, then routing. Every path through this function logs the
/// response status before returning.
async fn handle_request(
    req: Request<hyper::body::Incoming>,
    ollama_config: std::sync::Arc<OllamaConfig>,
    socket_addr: SocketAddr,
) -> Result<Response<BoxBody>, hyper::Error> {
    let method = req.method().clone();
    let uri_path = req.uri().path().to_string();
    let client_ip = get_client_ip(&req, &socket_addr);
    log_request(&ollama_config.logger, &method, &uri_path, &client_ip).await;
    let response = if !ollama_config.is_ip_allowed(&client_ip) {
        ollama_config
            .logger
            .log(&format!(
                "Blocked request from unauthorized IP: {client_ip}"
            ))
            .await;
        handle_ip_blocked(&client_ip)
    } else if uri_path == "/" {
        // Root is forwarded verbatim (Ollama answers with its banner).
        forward_to_ollama(req, &ollama_config, "/", &client_ip).await
    } else if uri_path.starts_with("/api/") {
        handle_api_endpoint(req, &ollama_config, &uri_path, &client_ip).await
    } else {
        handle_not_found()
    };
    log_response(
        &ollama_config.logger,
        &method,
        &uri_path,
        &response.status(),
        &client_ip,
    )
    .await;
    Ok(response)
}
fn load_tls_config(
cert_path: &PathBuf,
key_path: &PathBuf,
) -> Result<ServerConfig, Box<dyn std::error::Error>> {
let cert_file = File::open(cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let mut cert_chain = Vec::new();
for cert_result in certs(&mut cert_reader) {
let cert = cert_result?;
cert_chain.push(cert);
}
if cert_chain.is_empty() {
return Err("No certificates found in certificate file".into());
}
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
let mut private_keys = Vec::new();
for key_result in pkcs8_private_keys(&mut key_reader) {
let key = key_result?;
private_keys.push(key);
}
if private_keys.is_empty() {
return Err("No private keys found in key file".into());
}
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, private_keys.remove(0).into())?;
Ok(config)
}
/// Build the shared `OllamaConfig` from CLI arguments and announce the
/// effective configuration (upstream URL, auth, IP allowlist, log-file
/// rotation, database logging) on the logger. Called once at startup by
/// both the HTTP and HTTPS runners.
async fn setup_ollama_server(
    args: &Args,
    logger: &Arc<Logger>,
) -> Result<Arc<OllamaConfig>, Box<dyn std::error::Error>> {
    // `db_url` only exists as a CLI flag when the database-logging feature
    // is compiled in; otherwise it is hard-wired to None here.
    #[cfg(feature = "database-logging")]
    let db_url = args.db_url.clone();
    #[cfg(not(feature = "database-logging"))]
    let db_url = None;
    let db_logger = Arc::new(DbLogger::new(db_url).await);
    let ollama_config = OllamaConfig::new(
        args.ollama_url.clone(),
        logger.clone(),
        args.api_key.clone(),
        args.allowed_ips.clone(),
        args.min_keep_alive.clone(),
        db_logger,
    );
    logger
        .log(&format!(
            "Forwarding requests to Ollama server at: {}",
            ollama_config.base_url
        ))
        .await;
    // Surface the security posture loudly at startup.
    if args.api_key.is_some() {
        logger
            .log("API authentication enabled for model management endpoints")
            .await;
    } else {
        logger.log("WARNING: API authentication not configured. All endpoints are publicly accessible!").await;
    }
    if let Some(ref allowed_ips) = ollama_config.allowed_ips {
        // An allowlist that parsed to zero entries blocks every request.
        if allowed_ips.is_empty() {
            logger
                .log("WARNING: IP allowlist is empty. All requests will be blocked!")
                .await;
        } else {
            logger
                .log(&format!(
                    "IP address allowlist enabled. Only {} IP addresses are allowed to connect.",
                    allowed_ips.len()
                ))
                .await;
            // Only echo the individual addresses for short lists.
            if allowed_ips.len() <= 10 {
                logger
                    .log(&format!(
                        "Allowed IPs: {}",
                        allowed_ips
                            .iter()
                            .map(std::string::ToString::to_string)
                            .collect::<Vec<_>>()
                            .join(", ")
                    ))
                    .await;
            }
        }
    }
    // Describe the log-file rotation policy when file logging is active.
    if let Some(ref log_path) = args.log_file {
        let rotation_msg = if args.max_log_files > 0 {
            format!(
                "Logging to file: {} (rotation at {}, keeping max {} rotated files)",
                log_path.display(),
                args.log_rotate_size,
                args.max_log_files
            )
        } else {
            format!(
                "Logging to file: {} (rotation at {}, no limit on rotated files)",
                log_path.display(),
                args.log_rotate_size
            )
        };
        logger.log(&rotation_msg).await;
    } else {
        logger.log("Logging to console only (no log file)").await;
    }
    // Database-logging status messages, depending on the compiled feature
    // set and whether a URL was actually supplied.
    #[cfg(feature = "database-logging")]
    if let Some(ref url) = args.db_url {
        logger
            .log(&format!("Database logging enabled: {url}"))
            .await;
    }
    #[cfg(feature = "database-logging")]
    if args.db_url.is_none() {
        logger.log("Database logging feature is enabled but no DB URL provided. Use --db-url to enable database logging.").await;
    }
    #[cfg(not(feature = "database-logging"))]
    logger.log("Database logging feature is disabled. Recompile with --feature=database-logging to enable.").await;
    Ok(Arc::new(ollama_config))
}
/// Bind the TCP listener and announce the server configuration in the log.
async fn setup_server_listener(
    addr: SocketAddr,
    protocol: &str,
    ollama_config: &Arc<OllamaConfig>,
    logger: &Arc<Logger>,
) -> Result<TcpListener, Box<dyn std::error::Error>> {
    let listener = TcpListener::bind(addr).await?;
    logger
        .log(&format!("REST API server listening on {protocol}://{addr}"))
        .await;
    logger
        .log(&format!(
            "Root endpoint (/) forwards to Ollama server at: {}",
            ollama_config.base_url
        ))
        .await;
    log_api_endpoints(logger).await;
    // Announce the keep_alive floor when one is configured; a negative
    // value is reported as "infinite".
    match ollama_config.min_keep_alive_seconds {
        Some(min_seconds) if min_seconds < 0 => {
            logger.log("Minimum keep_alive time set to infinite").await;
        }
        Some(min_seconds) => {
            logger
                .log(&format!(
                    "Minimum keep_alive time set to {min_seconds} seconds"
                ))
                .await;
        }
        None => {}
    }
    Ok(listener)
}
/// Accept loop for plain-HTTP mode: each connection is served on its own
/// tokio task. Runs until `accept` fails.
async fn run_ollama_proxy(
    addr: SocketAddr,
    args: Args,
    logger: Arc<Logger>,
) -> Result<(), Box<dyn std::error::Error>> {
    let ollama_config = setup_ollama_server(&args, &logger).await?;
    let listener = setup_server_listener(addr, "http", &ollama_config, &logger).await?;
    loop {
        let (tcp_stream, peer_addr) = listener.accept().await?;
        logger.log(&format!("Connection from: {peer_addr}")).await;
        let config = ollama_config.clone();
        tokio::task::spawn(async move {
            let io = TokioIo::new(tcp_stream);
            let service = hyper::service::service_fn(move |req| {
                let config = config.clone();
                async move { handle_request(req, config, peer_addr).await }
            });
            if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
                eprintln!("Error serving connection: {err:?}");
            }
        });
    }
}
/// Accept loop for HTTPS mode: per-connection TLS handshake, then HTTP/1
/// service on its own tokio task.
///
/// Fix: a failed TLS handshake used to route its error message through
/// `tokio::task::spawn_blocking`, which is intended for blocking work —
/// formatting a string is not blocking, so the message is now built and
/// logged directly on the async task.
async fn run_https_server(
    addr: SocketAddr,
    tls_config: ServerConfig,
    args: Args,
    logger: Arc<Logger>,
) -> Result<(), Box<dyn std::error::Error>> {
    let ollama_config = setup_ollama_server(&args, &logger).await?;
    let listener = setup_server_listener(addr, "https", &ollama_config, &logger).await?;
    let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
    loop {
        let (tcp_stream, addr) = listener.accept().await?;
        logger.log(&format!("Connection from: {addr}")).await;
        let tls_acceptor = tls_acceptor.clone();
        let ollama_config = ollama_config.clone();
        let logger = logger.clone();
        tokio::task::spawn(async move {
            match tls_acceptor.accept(tcp_stream).await {
                Ok(tls_stream) => {
                    let io = TokioIo::new(tls_stream);
                    let client_ip = addr;
                    let service = hyper::service::service_fn(move |req| {
                        let config = ollama_config.clone();
                        let client_addr = client_ip;
                        async move { handle_request(req, config, client_addr).await }
                    });
                    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
                        eprintln!("Error serving TLS connection: {err:?}");
                    }
                }
                Err(e) => {
                    logger.log(&format!("TLS handshake error: {e}")).await;
                }
            }
        });
    }
}
/// Emit the startup banner listing every API endpoint the proxy understands.
async fn log_api_endpoints(logger: &Logger) {
    // Kept as a data table so the list is easy to scan and extend.
    let lines = [
        "API endpoints:",
        " POST /api/generate - Generate text from a model",
        " POST /api/chat - Generate the next message in a chat with a provided model",
        " POST /api/embed - Generate embeddings from a model",
        " POST /api/embeddings - Deprecated. Similar to /api/embed",
        " GET /api/tags - List available models",
        " POST /api/create - Create a new model (auth required)",
        " POST /api/copy - Copy a model (auth required)",
        " DELETE /api/delete - Delete a model (auth required)",
        " POST /api/pull - Pull a model (auth required)",
        " POST /api/push - Push a model (auth required)",
        " Note: To unload a model, use /api/generate with empty prompt and keep_alive: 0",
        " Note: keep_alive supports time formats like \"30s\", \"5m\", \"1h30m\", \"3h1m5s\", \"-1s\" (infinite)",
    ];
    for line in lines {
        logger.log(line).await;
    }
}
/// Emit a structured JSON log record for a request or response, including
/// headers and (when parseable) the JSON body.
#[allow(clippy::too_many_arguments)]
async fn log_detailed_json(
    logger: &Logger,
    direction: &str,
    method: &Method,
    path: &str,
    status: Option<StatusCode>,
    body_bytes: &Bytes,
    client_ip: &SocketAddr,
    headers: &hyper::HeaderMap<hyper::header::HeaderValue>,
) {
    // Represent the body as parsed JSON when possible, otherwise wrap the
    // lossily-decoded text; empty bodies get an explicit marker.
    let body_str = String::from_utf8_lossy(body_bytes);
    let body_json = if body_bytes.is_empty() {
        serde_json::json!({ "content": "<empty>" })
    } else {
        serde_json::from_str::<serde_json::Value>(&body_str)
            .unwrap_or_else(|_| serde_json::json!({ "content": body_str }))
    };
    // Requests carry no status; 0 serves as the "not applicable" marker.
    let status_code = status.map_or(0, |s| s.as_u16());
    // BTreeMap keeps header keys sorted for stable log output; non-UTF-8
    // header values are recorded as empty strings.
    let headers_json: std::collections::BTreeMap<String, String> = headers
        .iter()
        .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
        .collect();
    let log_entry = serde_json::json!({
        "timestamp": Local::now().to_rfc3339(),
        "direction": direction,
        "client_ip": client_ip.to_string(),
        "method": method.to_string(),
        "path": path,
        "status": status_code,
        "headers": headers_json,
        "body": body_json
    });
    match serde_json::to_string_pretty(&log_entry) {
        Ok(log_json) => logger.log(&format!("DETAILED JSON LOG: {log_json}")).await,
        Err(_) => logger.log("Failed to serialize detailed log to JSON").await,
    }
}
/// Entry point: parse CLI/env configuration, set up logging, then launch
/// the proxy in HTTP or HTTPS mode.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();
    // Validate the bind address before doing anything else.
    let ip: std::net::IpAddr = match args.host.parse() {
        Ok(addr) => addr,
        Err(e) => {
            eprintln!("Error parsing host address: {e}. Please provide a valid IP address.");
            return Err(format!("Invalid host address: {}", args.host).into());
        }
    };
    let addr = SocketAddr::new(ip, args.port);
    let logger = Arc::new(
        Logger::new(
            args.log_file.clone(),
            args.log_rotate_size.clone(),
            args.max_log_files,
        )
        .await,
    );
    logger.log("Ollama Proxy Server starting up").await;
    logger.log("Arguments can be provided via command line or environment variables with prefix PROXY_OLLAMA_").await;
    // Plain-HTTP path: warn when exposed on all interfaces, then serve.
    if !args.https {
        logger.log("HTTP mode enabled (no encryption)").await;
        if args.host == "0.0.0.0" {
            logger
                .log("WARNING: Server is listening on all network interfaces without encryption")
                .await;
        }
        run_ollama_proxy(addr, args, logger).await?;
        return Ok(());
    }
    // HTTPS path: both certificate and key files are mandatory.
    let (Some(cert_file), Some(key_file)) = (args.cert_file.as_ref(), args.key_file.as_ref())
    else {
        eprintln!("Error: HTTPS mode requires both --cert-file and --key-file parameters");
        eprintln!("\nExample usage:");
        eprintln!(
            " cargo run -- --https --cert-file path/to/cert.pem --key-file path/to/key.pem"
        );
        eprintln!(
            "\nYou can generate a self-signed certificate for testing using the provided script:"
        );
        eprintln!(" ./generate_cert.sh");
        return Err("HTTPS mode requires both --cert-file and --key-file parameters".into());
    };
    let tls_config = match load_tls_config(cert_file, key_file) {
        Ok(config) => config,
        Err(e) => {
            eprintln!("Error loading TLS configuration: {e}");
            return Err(format!("Failed to load TLS configuration: {e}").into());
        }
    };
    logger.log("HTTPS mode enabled").await;
    logger
        .log(&format!("Using certificate file: {}", cert_file.display()))
        .await;
    logger
        .log(&format!("Using private key file: {}", key_file.display()))
        .await;
    run_https_server(addr, tls_config, args, logger).await?;
    Ok(())
}