use std::collections::HashSet;
use tracing::info;
use crate::compose::types::{ComposeFile, Service};
use crate::engine::Engine;
use crate::error::{ComposeError, Result};
use crate::libpod::{urlencoded, API_PREFIX};
/// Rejects scaling a service past one replica when it publishes explicit,
/// non-zero host ports.
///
/// Returns `Ok(())` when `replicas` is 0 or 1, or when none of the service's
/// parsed port mappings carries a non-zero host port.
///
/// # Errors
///
/// * Propagates any error from `crate::ports::parse_ports`.
/// * Returns [`ComposeError::ScalePortConflict`] carrying the service name,
///   the requested replica count and every offending host port otherwise.
pub(super) fn check_scale_port_conflict(
    service_name: &str,
    service: &Service,
    replicas: usize,
) -> Result<()> {
    // Fewer than two replicas: nothing to check.
    if replicas < 2 {
        return Ok(());
    }

    let parsed = crate::ports::parse_ports(&service.ports)?;

    // Keep only mappings with a host port that is set and non-zero; an absent
    // or zero host port is not treated as a conflict.
    let mut pinned: Vec<u16> = Vec::new();
    for mapping in parsed.iter() {
        match mapping.host_port {
            Some(port) if port != 0 => pinned.push(port),
            _ => {}
        }
    }

    if pinned.is_empty() {
        Ok(())
    } else {
        Err(ComposeError::ScalePortConflict {
            service: service_name.to_string(),
            replicas,
            ports: pinned,
        })
    }
}
impl Engine {
/// Applies the requested `(service name, replica count)` pairs.
///
/// Steps, in order:
/// 1. Every named service must exist in `file`.
/// 2. Every pair must pass [`check_scale_port_conflict`].
/// 3. `up_with_options` is invoked with the list of named services.
/// 4. For each pair, containers of that service that are not part of the
///    desired replica set are stopped and removed.
///
/// NOTE(review): the replica counts in `pairs` are not forwarded to
/// `up_with_options`; how many replicas it starts is decided there and is
/// not visible from this function. The meaning of its positional flags is
/// likewise defined by `up_with_options`.
///
/// # Errors
///
/// * [`ComposeError::ServiceNotFound`] for the first pair naming a service
///   that is absent from `file`. This is reported before any port check.
/// * Any error from [`check_scale_port_conflict`], `up_with_options`, or
///   `remove_surplus_replicas`.
pub async fn scale(&self, file: &ComposeFile, pairs: &[(String, u32)]) -> Result<()> {
    // Pass 1: reject unknown services up front so that the indexing below
    // (`file.services[name]`) cannot panic.
    if let Some(missing) = pairs
        .iter()
        .find(|pair| !file.services.contains_key(&pair.0))
    {
        return Err(ComposeError::ServiceNotFound(missing.0.clone()));
    }

    // Pass 2: port checks run only once all names are known to exist.
    pairs.iter().try_for_each(|(name, count)| {
        check_scale_port_conflict(name, &file.services[name], *count as usize)
    })?;

    let targets: Vec<String> = pairs.iter().map(|(name, _)| name.clone()).collect();
    self.up_with_options(file, true, &[], &targets, true, false, true)
        .await?;

    // Trim whatever is left over beyond the desired replica set.
    for (name, count) in pairs {
        let definition = &file.services[name];
        self.remove_surplus_replicas(name, definition, *count)
            .await?;
    }

    Ok(())
}
/// Stops and removes the containers of `service_name` whose names are not in
/// the desired replica set for `target`.
///
/// The desired set is derived from `self.container_name(..)`:
/// * `target <= 1` keeps exactly the base name;
/// * otherwise it keeps `"{base}-1"` through `"{base}-{target}"`.
///
/// NOTE(review): `target == 0` is handled like `target == 1`, so the
/// base-named container is kept. Confirm this is intended.
///
/// # Errors
///
/// Fails only if listing the project's containers fails; the per-container
/// stop/remove step does not return an error.
async fn remove_surplus_replicas(
    &self,
    service_name: &str,
    service: &Service,
    target: u32,
) -> Result<()> {
    let base = self.container_name(service_name, service);

    let mut keep: HashSet<String> = HashSet::new();
    if target > 1 {
        keep.extend((1..=target).map(|n| format!("{base}-{n}")));
    } else {
        keep.insert(base);
    }

    let grace = self.grace_period_secs(service);
    let existing = self
        .list_project_container_names(Some(service_name))
        .await?;

    for name in existing {
        if keep.contains(&name) {
            continue;
        }
        self.stop_and_remove(&name, grace).await;
    }

    Ok(())
}
/// Stops the container `name`, then force-removes it. Both steps are
/// best-effort.
///
/// `grace` is sent as the `t` query parameter of the stop request. The
/// removal request is sent with `force=true` regardless of whether the stop
/// succeeded.
///
/// This function never returns an error: a failed stop or a failed removal
/// is logged at debug level, and a successful removal is logged at info
/// level.
pub(super) async fn stop_and_remove(&self, name: &str, grace: i32) {
    let stop_path = format!(
        "{API_PREFIX}/containers/{}/stop?t={grace}",
        urlencoded(name)
    );
    // A failed stop must not prevent the forced removal below, but it should
    // leave a trace instead of being discarded silently.
    if let Err(e) = self.client.post_empty_ok(&stop_path).await {
        tracing::debug!("scale-down stop {name}: {e}");
    }

    let rm_path = format!("{API_PREFIX}/containers/{}?force=true", urlencoded(name));
    match self.client.delete_ok(&rm_path).await {
        Ok(_) => info!("removed {name}"),
        Err(e) => tracing::debug!("scale-down rm {name}: {e}"),
    }
}
/// Lists the names of containers labelled as belonging to this project,
/// optionally narrowed to a single service.
///
/// The request filters on the label `podup.project=<self.project>` and, when
/// `service` is `Some`, additionally on `podup.service=<service>`. It is sent
/// with `all=true`.
///
/// Every name of every returned entry is included, with leading `/`
/// characters stripped.
///
/// # Errors
///
/// A failed request is returned wrapped in [`ComposeError::Podman`].
pub(crate) async fn list_project_container_names(
    &self,
    service: Option<&str>,
) -> Result<Vec<String>> {
    let project_label = format!("podup.project={}", self.project);
    let labels: Vec<String> = std::iter::once(project_label)
        .chain(service.map(|svc| format!("podup.service={svc}")))
        .collect();

    let filters = serde_json::json!({ "label": labels }).to_string();
    let path = format!(
        "{API_PREFIX}/containers/json?all=true&filters={}",
        urlencoded(&filters),
    );

    let entries = self
        .client
        .get_json::<Vec<crate::libpod::types::container::ContainerListEntry>>(&path)
        .await
        .map_err(ComposeError::Podman)?;

    // Flatten every entry's names into one list, dropping leading slashes.
    let mut names = Vec::new();
    for entry in entries {
        for raw in entry.names {
            names.push(raw.trim_start_matches('/').to_string());
        }
    }
    Ok(names)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `check_scale_port_conflict`.

    use super::check_scale_port_conflict;
    use crate::compose::types::Service;
    use crate::error::ComposeError;

    /// Compose document whose single service publishes host port 8080.
    const FIXED_PORT_YAML: &str = "services:\n web:\n image: x\n ports:\n - \"8080:80\"\n";

    /// Parses `yaml` and returns the first service definition it contains.
    fn first_service(yaml: &str) -> Service {
        let parsed = crate::parse_str(yaml).unwrap();
        let (_name, definition) = parsed.services.into_iter().next().unwrap();
        definition
    }

    #[test]
    fn single_replica_never_conflicts() {
        let web = first_service(FIXED_PORT_YAML);
        let outcome = check_scale_port_conflict("web", &web, 1);
        assert!(outcome.is_ok());
    }

    #[test]
    fn scaled_fixed_host_port_conflicts() {
        let web = first_service(FIXED_PORT_YAML);
        let err = check_scale_port_conflict("web", &web, 3).unwrap_err();
        assert!(matches!(err, ComposeError::ScalePortConflict { .. }));
        // The rendered error must mention the offending host port.
        let rendered = err.to_string();
        assert!(rendered.contains("8080"));
    }

    #[test]
    fn scaled_random_host_port_is_allowed() {
        let web = first_service("services:\n web:\n image: x\n ports:\n - \"80\"\n");
        let outcome = check_scale_port_conflict("web", &web, 3);
        assert!(outcome.is_ok());
    }

    #[test]
    fn scaled_no_ports_is_allowed() {
        let worker = first_service("services:\n worker:\n image: x\n");
        let outcome = check_scale_port_conflict("worker", &worker, 5);
        assert!(outcome.is_ok());
    }
}