use std::env;
use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use super::error::BackendError;
use super::openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
use super::{Backend, Capability, ChatRequest, ChatResponse, ChatStream};
/// Name of the environment variable that [`GLMBackend::from_env`] reads the API key from.
const API_KEY_ENV: &str = "GLM_API_KEY";
/// GLM backend implemented as a thin wrapper around an [`OpenAICompatibleBackend`]
/// that is constructed with [`OpenAICompatConfig::glm`].
///
/// The [`Backend`] implementation forwards every call to the wrapped backend, except
/// that `supports` answers the `Vision` and `LockedParams` capabilities itself.
pub struct GLMBackend {
/// The wrapped OpenAI-compatible backend that all calls are delegated to.
inner: OpenAICompatibleBackend,
}
impl GLMBackend {
    /// Creates a backend whose API key is read from the `GLM_API_KEY` environment variable.
    ///
    /// If the variable is unset or does not contain valid Unicode, the backend is
    /// constructed without an API key (`None`).
    pub fn from_env() -> Self {
        let api_key = env::var(API_KEY_ENV).ok();
        Self::with_api_key(api_key)
    }

    /// Creates a backend from an explicitly supplied API key (`None` for no key),
    /// wrapping an [`OpenAICompatibleBackend`] built from [`OpenAICompatConfig::glm`].
    pub fn with_api_key(api_key: Option<String>) -> Self {
        let inner = OpenAICompatibleBackend::new(OpenAICompatConfig::glm(), api_key);
        Self { inner }
    }

    /// Builder-style setter: passes `base_url` to the wrapped backend's `with_base_url`
    /// and returns the resulting backend.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            inner: self.inner.with_base_url(base_url),
        }
    }

    /// Builder-style setter: passes `model` to the wrapped backend's `with_default_model`
    /// and returns the resulting backend.
    pub fn with_default_model(self, model: impl Into<String>) -> Self {
        Self {
            inner: self.inner.with_default_model(model),
        }
    }

    /// Borrows the wrapped [`OpenAICompatibleBackend`].
    pub fn inner(&self) -> &OpenAICompatibleBackend {
        &self.inner
    }
}
impl Default for GLMBackend {
fn default() -> Self {
Self::from_env()
}
}
#[async_trait]
impl Backend for GLMBackend {
    /// Returns the name reported by the wrapped backend.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Returns the default model reported by the wrapped backend.
    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    /// Forwards `request` unchanged to the wrapped backend's `complete`.
    async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError> {
        self.inner.complete(request).await
    }

    /// Forwards `request` unchanged to the wrapped backend's `stream`.
    async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError> {
        self.inner.stream(request).await
    }

    /// Forwards token counting to the wrapped backend.
    fn count_tokens(&self, model: &str, text: &str) -> usize {
        self.inner.count_tokens(model, text)
    }

    /// Reports whether `capability` is supported for `model`.
    ///
    /// Two capabilities are answered here; everything else is delegated to the
    /// wrapped backend:
    /// - `LockedParams` is always reported as unsupported.
    /// - `Vision` is supported exactly when the lower-cased model name starts
    ///   with `"glm-4v"`.
    fn supports(&self, capability: Capability, model: &str) -> bool {
        if matches!(capability, Capability::LockedParams) {
            return false;
        }
        if matches!(capability, Capability::Vision) {
            // Lower-case first so the prefix test is case-insensitive.
            let normalized = model.to_lowercase();
            return normalized.starts_with("glm-4v");
        }
        self.inner.supports(capability, model)
    }
}
/// Module-level convenience factory; identical to [`GLMBackend::from_env`].
pub fn from_env() -> GLMBackend {
    GLMBackend::from_env()
}
/// Module-level convenience factory; identical to [`GLMBackend::with_api_key`].
pub fn with_api_key(api_key: Option<String>) -> GLMBackend {
    GLMBackend::with_api_key(api_key)
}
/// A pinned, boxed, `Send` stream whose items are `Result<ChatChunk, BackendError>`.
// NOTE(review): nothing in the visible part of this file refers to this alias, hence the
// `dead_code` allowance. It is the only visible use of the `Pin` and `Stream` imports;
// presumably it is kept to spell out the boxed-stream shape — confirm before removing.
#[allow(dead_code)]
type GLMChatStream =
Pin<Box<dyn Stream<Item = Result<crate::backends::ChatChunk, BackendError>> + Send>>;
#[cfg(test)]
mod tests {
    // Unit tests for `GLMBackend`: construction paths, capability overrides,
    // delegation to the wrapped backend, and error mapping without a live server.
    use super::*;
    use crate::backends::openai_compat::build_request_body;
    use crate::backends::Message;

    /// Builds a [`ChatRequest`] carrying `messages`, with an empty model name and
    /// every other field left at its default.
    fn req_with(messages: Vec<Message>) -> ChatRequest {
        ChatRequest {
            model: String::new(),
            messages,
            ..Default::default()
        }
    }

    /// `from_env` yields a backend named "glm" with the "glm-4-plus" default model.
    #[test]
    fn from_env_constructs_glm_backend() {
        let b = GLMBackend::from_env();
        assert_eq!(b.name(), "glm");
        assert_eq!(b.default_model(), "glm-4-plus");
    }

    /// The module-level `from_env` factory produces the same kind of backend.
    #[test]
    fn module_factory_from_env_works() {
        let b = from_env();
        assert_eq!(b.name(), "glm");
    }

    /// The module-level `with_api_key` factory accepts an explicit key.
    #[test]
    fn module_factory_with_api_key_explicit() {
        let b = with_api_key(Some("zhipu-test-key".into()));
        assert_eq!(b.name(), "glm");
    }

    /// `with_default_model` replaces the reported default model.
    #[test]
    fn with_default_model_overrides() {
        let b = GLMBackend::with_api_key(Some("k".into())).with_default_model("glm-4-air");
        assert_eq!(b.default_model(), "glm-4-air");
    }

    /// `with_base_url` can be chained onto a constructed backend.
    #[test]
    fn with_base_url_overrides_for_test_fixtures() {
        let b = GLMBackend::with_api_key(Some("k".into())).with_base_url("http://localhost:9999");
        // No base-URL accessor is visible from this module, so the URL itself cannot be
        // inspected; at least verify the builder still hands back the GLM backend
        // instead of asserting nothing at all.
        assert_eq!(b.name(), "glm");
    }

    /// `inner()` exposes the wrapped OpenAI-compatible backend.
    #[test]
    fn inner_accessor_returns_compat_backend() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        assert_eq!(b.inner().name(), "glm");
    }

    /// `Default::default` goes through `from_env`.
    #[test]
    fn default_constructs_via_from_env() {
        let b = GLMBackend::default();
        assert_eq!(b.name(), "glm");
    }

    /// Model names starting with "glm-4v" report vision support.
    #[test]
    fn supports_vision_for_glm_4v_family() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::Vision, "glm-4v"));
        assert!(b.supports(Capability::Vision, "glm-4v-plus"));
        assert!(b.supports(Capability::Vision, "glm-4v-flash"));
    }

    /// Model names without the "glm-4v" prefix do not report vision support.
    #[test]
    fn does_not_support_vision_for_chat_only_models() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::Vision, "glm-4-plus"));
        assert!(!b.supports(Capability::Vision, "glm-4-air"));
        assert!(!b.supports(Capability::Vision, "glm-4-flash"));
        assert!(!b.supports(Capability::Vision, "glm-3-turbo"));
    }

    /// The vision prefix check ignores letter case.
    #[test]
    fn vision_is_case_insensitive() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::Vision, "GLM-4v-plus"));
    }

    /// `LockedParams` is reported as unsupported for every model tried.
    #[test]
    fn does_not_support_lockedparams_for_any_glm_model() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::LockedParams, "glm-4-plus"));
        assert!(!b.supports(Capability::LockedParams, "glm-4-air"));
        assert!(!b.supports(Capability::LockedParams, "glm-4v-plus"));
    }

    /// Capabilities not overridden by `GLMBackend` fall through to the wrapped backend,
    /// which reports these three as supported.
    #[test]
    fn supports_streaming_tooluse_structured_via_base() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        let any_model = "glm-4-plus";
        assert!(b.supports(Capability::Streaming, any_model));
        assert!(b.supports(Capability::ToolUse, any_model));
        assert!(b.supports(Capability::StructuredOutput, any_model));
    }

    /// Prompt caching and safety settings are reported as unsupported.
    #[test]
    fn does_not_support_anthropic_or_gemini_only_caps() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        let any_model = "glm-4-plus";
        assert!(!b.supports(Capability::PromptCaching, any_model));
        assert!(!b.supports(Capability::SafetySettings, any_model));
    }

    /// `temperature` and `top_p` survive into the request body for GLM model names.
    #[test]
    fn body_keeps_sampling_params_for_glm_models() {
        for model in &["glm-4-plus", "glm-4-air", "glm-4-flash", "glm-4v-plus"] {
            let mut req = req_with(vec![Message::user("hi")]);
            req.model = (*model).into();
            req.temperature = Some(0.5);
            req.top_p = Some(0.9);
            // NOTE(review): the second argument is the same "glm-4-plus" literal on every
            // iteration. Presumably it is a default-model fallback and `req.model` takes
            // precedence — confirm against `build_request_body`; otherwise this loop only
            // ever exercises a single model.
            let body = build_request_body(&req, "glm-4-plus", false);
            assert_eq!(body["temperature"], 0.5, "model {model} should keep temperature");
            assert_eq!(body["top_p"], 0.9, "model {model} should keep top_p");
        }
    }

    /// Token counting for a short text yields a small, non-zero count.
    #[test]
    fn count_tokens_uses_cl100k_for_glm_models() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        let n = b.count_tokens("glm-4-plus", "hello world");
        assert!(n > 0);
        assert!(n <= 5);
    }

    /// Token counting also works for a vision model name.
    #[test]
    fn count_tokens_uses_cl100k_for_glm_4v() {
        let b = GLMBackend::with_api_key(Some("k".into()));
        let n = b.count_tokens("glm-4v-plus", "hello world");
        assert!(n > 0);
    }

    /// Streaming against an address nothing is expected to listen on surfaces a
    /// `Generic` transport error from the wrapped backend.
    #[tokio::test]
    async fn stream_delegates_to_base_real_sse_implementation() {
        let b = GLMBackend::with_api_key(Some("k".into())).with_base_url("http://127.0.0.1:1");
        match b.stream(ChatRequest::default()).await {
            Err(BackendError::Generic { ref message, .. }) => {
                // A single substring check suffices: any message containing
                // "streaming transport failure" necessarily contains "transport".
                assert!(
                    message.contains("transport"),
                    "unexpected message: {message}"
                );
            }
            Err(other) => panic!("expected Generic, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    /// Completing without an API key fails with an `Auth` error that names the
    /// environment variable to set.
    #[tokio::test]
    async fn complete_without_api_key_returns_auth_error() {
        let b = GLMBackend::with_api_key(None).with_base_url("http://127.0.0.1:0");
        let err = b
            .complete(ChatRequest {
                messages: vec![Message::user("hi")],
                ..Default::default()
            })
            .await
            .unwrap_err();
        match err {
            BackendError::Auth { api_key_env, .. } => {
                assert_eq!(api_key_env.as_deref(), Some(API_KEY_ENV));
            }
            other => panic!("expected Auth, got {other:?}"),
        }
    }
}