1#![cfg(feature = "async")]
34
35use crate::client::Client;
36use crate::dry_run::DryRun;
37use crate::error::Result;
38use crate::messages::request::{CountTokensRequest, CreateMessageRequest};
39use crate::messages::response::{CountTokensResponse, Message};
40
41#[cfg(feature = "streaming")]
42use crate::messages::stream::EventStream;
43
/// Handle for the Messages endpoints of a [`Client`].
///
/// The handle stores only a shared borrow of the client, so it cannot outlive
/// the client it was created from.
pub struct Messages<'a> {
    /// Client through which every request issued by this handle is built and sent.
    client: &'a Client,
}
51
52impl<'a> Messages<'a> {
53 pub(crate) fn new(client: &'a Client) -> Self {
54 Self { client }
55 }
56
57 pub async fn create(&self, request: CreateMessageRequest) -> Result<Message> {
62 self.create_with_beta(request, &[]).await
63 }
64
65 pub async fn create_with_beta(
68 &self,
69 request: CreateMessageRequest,
70 betas: &[&str],
71 ) -> Result<Message> {
72 let request_ref = &request;
73 self.client
74 .execute_with_retry(
75 || {
76 self.client
77 .request_builder(reqwest::Method::POST, "/v1/messages")
78 .json(request_ref)
79 },
80 betas,
81 )
82 .await
83 }
84
85 pub async fn count_tokens(&self, request: CountTokensRequest) -> Result<CountTokensResponse> {
88 self.count_tokens_with_beta(request, &[]).await
89 }
90
91 pub async fn count_tokens_with_beta(
93 &self,
94 request: CountTokensRequest,
95 betas: &[&str],
96 ) -> Result<CountTokensResponse> {
97 let request_ref = &request;
98 self.client
99 .execute_with_retry(
100 || {
101 self.client
102 .request_builder(reqwest::Method::POST, "/v1/messages/count_tokens")
103 .json(request_ref)
104 },
105 betas,
106 )
107 .await
108 }
109
110 pub fn dry_run(&self, request: &CreateMessageRequest) -> Result<DryRun> {
115 self.dry_run_with_beta(request, &[])
116 }
117
118 pub fn dry_run_with_beta(
121 &self,
122 request: &CreateMessageRequest,
123 betas: &[&str],
124 ) -> Result<DryRun> {
125 let builder = self
126 .client
127 .request_builder(reqwest::Method::POST, "/v1/messages")
128 .json(request);
129 self.client.render_dry_run(builder, betas)
130 }
131
/// Estimates what `request` would cost before it is sent.
///
/// Counts the input tokens with one call to [`Self::count_tokens`], then prices the
/// result with `pricing`. The upper bounds assume that all `request.max_tokens` output
/// tokens are produced.
///
/// # Errors
/// Returns whatever error the token-count request reports.
#[cfg(feature = "pricing")]
#[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
pub async fn cost_preview(
    &self,
    request: &CreateMessageRequest,
    pricing: &crate::pricing::PricingTable,
) -> Result<crate::cost_preview::CostPreview> {
    use crate::types::Usage;
    let count = self.count_tokens(CountTokensRequest::from(request)).await?;
    let input_usage = Usage {
        input_tokens: count.input_tokens,
        output_tokens: 0,
        ..Usage::default()
    };
    Ok(Self::preview_from_input_usage(request, pricing, &input_usage))
}

/// Like [`Self::cost_preview`], but consults `cache` before calling the token-count endpoint.
///
/// The cache key is `hash_request` of the derived [`CountTokensRequest`]. On a hit no
/// network call is made; on a miss the count is fetched and stored under that key.
///
/// # Errors
/// Returns whatever error the token-count request reports on a cache miss.
#[cfg(feature = "pricing")]
#[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
pub async fn cost_preview_cached(
    &self,
    request: &CreateMessageRequest,
    pricing: &crate::pricing::PricingTable,
    cache: &crate::cost_preview::CountTokensCache,
) -> Result<crate::cost_preview::CostPreview> {
    use crate::types::Usage;
    let count_req = CountTokensRequest::from(request);
    let key = crate::cost_preview::hash_request(&count_req);

    let input_tokens = if let Some(cached) = cache.get(key) {
        cached
    } else {
        let count = self.count_tokens(count_req).await?;
        cache.put(key, count.input_tokens);
        count.input_tokens
    };

    let input_usage = Usage {
        input_tokens,
        output_tokens: 0,
        ..Usage::default()
    };
    Ok(Self::preview_from_input_usage(request, pricing, &input_usage))
}

/// Shared cost arithmetic for [`Self::cost_preview`] and [`Self::cost_preview_cached`].
///
/// `input_usage` must describe the input side only (its `output_tokens` is zero). Keeping
/// the arithmetic in one place guarantees both entry points price a request identically.
#[cfg(feature = "pricing")]
fn preview_from_input_usage(
    request: &CreateMessageRequest,
    pricing: &crate::pricing::PricingTable,
    input_usage: &crate::types::Usage,
) -> crate::cost_preview::CostPreview {
    use crate::types::Usage;
    let input_tokens = input_usage.input_tokens;
    let max_output_tokens = request.max_tokens;
    let input_cost_usd = pricing.cost(&request.model, input_usage);
    let max_total_usd = pricing.cost(
        &request.model,
        &Usage {
            input_tokens,
            output_tokens: max_output_tokens,
            ..Usage::default()
        },
    );
    // Output cost is the difference between the worst-case total and the input-only cost.
    let max_output_cost_usd = max_total_usd - input_cost_usd;
    crate::cost_preview::CostPreview {
        model: request.model.clone(),
        input_tokens,
        max_output_tokens,
        input_cost_usd,
        max_output_cost_usd,
        max_total_usd,
    }
}
231
232 pub fn dry_run_count_tokens(&self, request: &CountTokensRequest) -> Result<DryRun> {
234 self.dry_run_count_tokens_with_beta(request, &[])
235 }
236
237 pub fn dry_run_count_tokens_with_beta(
239 &self,
240 request: &CountTokensRequest,
241 betas: &[&str],
242 ) -> Result<DryRun> {
243 let builder = self
244 .client
245 .request_builder(reqwest::Method::POST, "/v1/messages/count_tokens")
246 .json(request);
247 self.client.render_dry_run(builder, betas)
248 }
249
/// Starts a streaming message request without any per-request beta flags.
///
/// Delegates to [`Self::create_stream_with_beta`] with an empty `betas` slice.
///
/// # Errors
/// Returns whatever error [`Self::create_stream_with_beta`] reports.
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub async fn create_stream(&self, request: CreateMessageRequest) -> Result<EventStream> {
    let no_betas: &[&str] = &[];
    self.create_stream_with_beta(request, no_betas).await
}
263
/// Sends `request` to `POST /v1/messages` with `stream` forced to `true` and wraps the
/// response in an [`EventStream`].
///
/// The request goes through the client's `execute_streaming`; `betas` is forwarded unchanged.
///
/// # Errors
/// Returns whatever error `execute_streaming` reports.
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub async fn create_stream_with_beta(
    &self,
    mut request: CreateMessageRequest,
    betas: &[&str],
) -> Result<EventStream> {
    // Whatever the caller set, a streaming call must ask for a streamed response.
    request.stream = true;
    let prepared = self
        .client
        .request_builder(reqwest::Method::POST, "/v1/messages")
        .json(&request);
    let raw = self.client.execute_streaming(prepared, betas).await?;
    Ok(EventStream::from_response(raw))
}
284}
285
286#[cfg(test)]
287mod tests {
288 use super::*;
289 use crate::messages::input::MessageInput;
290 use crate::messages::response::Message;
291 use crate::types::{ModelId, Role, StopReason};
292 use pretty_assertions::assert_eq;
293 use serde_json::json;
294 use wiremock::matchers::{body_partial_json, header, header_exists, method, path};
295 use wiremock::{Mock, MockServer, ResponseTemplate};
296
297 fn client_for(mock: &MockServer) -> Client {
298 Client::builder()
299 .api_key("sk-ant-test")
300 .base_url(mock.uri())
301 .build()
302 .unwrap()
303 }
304
305 fn fake_response_body() -> serde_json::Value {
306 json!({
307 "id": "msg_test",
308 "type": "message",
309 "role": "assistant",
310 "content": [{"type": "text", "text": "Hi!"}],
311 "model": "claude-sonnet-4-6",
312 "stop_reason": "end_turn",
313 "usage": {"input_tokens": 5, "output_tokens": 2}
314 })
315 }
316
317 #[tokio::test]
318 async fn create_posts_to_v1_messages_with_typed_request_body() {
319 let mock = MockServer::start().await;
320 Mock::given(method("POST"))
321 .and(path("/v1/messages"))
322 .and(header("x-api-key", "sk-ant-test"))
323 .and(header("anthropic-version", crate::ANTHROPIC_VERSION))
324 .and(body_partial_json(json!({
325 "model": "claude-sonnet-4-6",
326 "max_tokens": 64,
327 "messages": [{"role": "user", "content": "hi"}]
328 })))
329 .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
330 .mount(&mock)
331 .await;
332
333 let client = client_for(&mock);
334 let req = CreateMessageRequest::builder()
335 .model(ModelId::SONNET_4_6)
336 .max_tokens(64)
337 .user("hi")
338 .build()
339 .unwrap();
340 let resp = client.messages().create(req).await.unwrap();
341
342 assert_eq!(resp.id, "msg_test");
343 assert_eq!(resp.role, Role::Assistant);
344 assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
345 assert_eq!(resp.usage.input_tokens, 5);
346 }
347
348 #[tokio::test]
349 async fn create_with_beta_attaches_per_request_beta_header() {
350 let mock = MockServer::start().await;
351 Mock::given(method("POST"))
352 .and(path("/v1/messages"))
353 .and(header_exists("anthropic-beta"))
354 .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
355 .mount(&mock)
356 .await;
357
358 let client = client_for(&mock);
359 let req = CreateMessageRequest::builder()
360 .model(ModelId::SONNET_4_6)
361 .max_tokens(8)
362 .user("x")
363 .build()
364 .unwrap();
365
366 let _: Message = client
367 .messages()
368 .create_with_beta(req, &["computer-use-2025-01-24"])
369 .await
370 .unwrap();
371
372 let received = &mock.received_requests().await.unwrap()[0];
373 let beta = received
374 .headers
375 .get("anthropic-beta")
376 .unwrap()
377 .to_str()
378 .unwrap();
379 assert_eq!(beta, "computer-use-2025-01-24");
380 }
381
382 #[tokio::test]
383 async fn dry_run_renders_request_without_sending() {
384 let mock = MockServer::start().await;
386 let client = client_for(&mock);
387 let req = CreateMessageRequest::builder()
388 .model(ModelId::SONNET_4_6)
389 .max_tokens(64)
390 .user("hello")
391 .build()
392 .unwrap();
393
394 let dr = client.messages().dry_run(&req).unwrap();
395
396 assert_eq!(dr.method, reqwest::Method::POST);
397 assert_eq!(dr.url, format!("{}/v1/messages", mock.uri()));
398 assert_eq!(dr.headers.get("x-api-key").unwrap(), "sk-ant-test");
399 assert_eq!(
400 dr.headers.get("anthropic-version").unwrap(),
401 crate::ANTHROPIC_VERSION
402 );
403 assert_eq!(dr.body["model"], "claude-sonnet-4-6");
405 assert_eq!(dr.body["max_tokens"], 64);
406 assert_eq!(dr.body["messages"][0]["role"], "user");
407
408 assert_eq!(mock.received_requests().await.unwrap().len(), 0);
410 }
411
412 #[tokio::test]
413 async fn dry_run_with_beta_includes_anthropic_beta_header() {
414 let mock = MockServer::start().await;
415 let client = client_for(&mock);
416 let req = CreateMessageRequest::builder()
417 .model(ModelId::SONNET_4_6)
418 .max_tokens(8)
419 .user("x")
420 .build()
421 .unwrap();
422
423 let dr = client
424 .messages()
425 .dry_run_with_beta(&req, &["computer-use-2025-01-24"])
426 .unwrap();
427
428 assert_eq!(
429 dr.headers.get("anthropic-beta").unwrap(),
430 "computer-use-2025-01-24"
431 );
432 }
433
434 #[tokio::test]
435 async fn dry_run_count_tokens_uses_count_tokens_path() {
436 let mock = MockServer::start().await;
437 let client = client_for(&mock);
438 let req = CountTokensRequest::builder()
439 .model(ModelId::HAIKU_4_5)
440 .user("x")
441 .build()
442 .unwrap();
443
444 let dr = client.messages().dry_run_count_tokens(&req).unwrap();
445 assert!(dr.url.ends_with("/v1/messages/count_tokens"));
446 assert_eq!(dr.body["model"], "claude-haiku-4-5-20251001");
447 }
448
449 #[tokio::test]
450 async fn create_propagates_api_error_with_request_id() {
451 let mock = MockServer::start().await;
452 Mock::given(method("POST"))
453 .and(path("/v1/messages"))
454 .respond_with(
455 ResponseTemplate::new(400)
456 .insert_header("request-id", "req_xyz")
457 .set_body_json(json!({
458 "type": "error",
459 "error": {"type": "invalid_request_error", "message": "bad input"}
460 })),
461 )
462 .mount(&mock)
463 .await;
464
465 let client = client_for(&mock);
466 let req = CreateMessageRequest::builder()
467 .model(ModelId::SONNET_4_6)
468 .max_tokens(8)
469 .user("x")
470 .build()
471 .unwrap();
472
473 let err = client.messages().create(req).await.unwrap_err();
474 assert_eq!(err.request_id(), Some("req_xyz"));
475 assert_eq!(err.status(), Some(http::StatusCode::BAD_REQUEST));
476 }
477
478 #[tokio::test]
479 async fn count_tokens_posts_to_count_tokens_endpoint() {
480 let mock = MockServer::start().await;
481 Mock::given(method("POST"))
482 .and(path("/v1/messages/count_tokens"))
483 .and(body_partial_json(json!({
484 "model": "claude-haiku-4-5-20251001",
485 "messages": [{"role": "user", "content": "x"}]
486 })))
487 .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 7})))
488 .mount(&mock)
489 .await;
490
491 let client = client_for(&mock);
492 let req = CountTokensRequest::builder()
493 .model(ModelId::HAIKU_4_5)
494 .user("x")
495 .build()
496 .unwrap();
497 let resp = client.messages().count_tokens(req).await.unwrap();
498 assert_eq!(resp.input_tokens, 7);
499 }
500
/// `cost_preview` must query the count-tokens endpoint and derive the input, output and total
/// cost bounds from the returned count and the request's `max_tokens`.
#[cfg(feature = "pricing")]
#[tokio::test]
async fn cost_preview_calls_count_tokens_and_computes_bounds() {
    let server = MockServer::start().await;
    let expected_body = json!({
        "model": "claude-sonnet-4-6",
        "messages": [{"role": "user", "content": "hi"}]
    });
    let reply = ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 1000}));
    Mock::given(method("POST"))
        .and(path("/v1/messages/count_tokens"))
        .and(body_partial_json(expected_body))
        .respond_with(reply)
        .mount(&server)
        .await;

    let pricing = crate::pricing::PricingTable::default();
    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(2000)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let preview = client
        .messages()
        .cost_preview(&request, &pricing)
        .await
        .unwrap();

    assert_eq!(preview.input_tokens, 1000);
    assert_eq!(preview.max_output_tokens, 2000);
    // Costs are floating point, so each is compared within a small tolerance.
    assert!((preview.input_cost_usd - 0.003).abs() < 1e-9);
    assert!((preview.max_output_cost_usd - 0.030).abs() < 1e-9);
    assert!((preview.max_total_usd - 0.033).abs() < 1e-9);
}
540
/// `CostPreview::cost_for` must price a specific output-token count rather than the maximum.
#[cfg(feature = "pricing")]
#[tokio::test]
async fn cost_preview_cost_for_returns_point_estimate() {
    let server = MockServer::start().await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 1000}));
    Mock::given(method("POST"))
        .and(path("/v1/messages/count_tokens"))
        .respond_with(reply)
        .mount(&server)
        .await;

    let pricing = crate::pricing::PricingTable::default();
    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(5000)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let preview = client
        .messages()
        .cost_preview(&request, &pricing)
        .await
        .unwrap();
    // Price 500 output tokens on top of the 1000 counted input tokens.
    let point_estimate = preview.cost_for(500, &pricing);
    assert!((point_estimate - 0.0105).abs() < 1e-9);
}
569
/// Two cached previews of the same request must trigger exactly one count-tokens call.
#[cfg(feature = "pricing")]
#[tokio::test]
async fn cost_preview_cached_skips_network_on_hit() {
    let server = MockServer::start().await;
    let reply = ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 42}));
    // `expect(1)` makes the mock itself verify that only one request arrived.
    Mock::given(method("POST"))
        .and(path("/v1/messages/count_tokens"))
        .respond_with(reply)
        .expect(1)
        .mount(&server)
        .await;

    let cache = crate::cost_preview::CountTokensCache::new(8);
    let pricing = crate::pricing::PricingTable::default();
    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(64)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let first = client
        .messages()
        .cost_preview_cached(&request, &pricing, &cache)
        .await
        .unwrap();
    let second = client
        .messages()
        .cost_preview_cached(&request, &pricing, &cache)
        .await
        .unwrap();

    assert_eq!(first, second);
    assert_eq!(cache.len(), 1);
}
607
/// Requests with different content must get separate cache entries and separate counts.
#[cfg(feature = "pricing")]
#[tokio::test]
async fn cost_preview_cached_distinguishes_different_requests() {
    let server = MockServer::start().await;
    // One mock per prompt, each answering with its own token count.
    for (prompt, tokens) in [("alpha", 100), ("beta", 200)] {
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .and(body_partial_json(
                json!({"messages": [{"role": "user", "content": prompt}]}),
            ))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(json!({"input_tokens": tokens})),
            )
            .mount(&server)
            .await;
    }

    let client = client_for(&server);
    let cache = crate::cost_preview::CountTokensCache::new(8);
    let pricing = crate::pricing::PricingTable::default();

    let alpha_request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(64)
        .user("alpha")
        .build()
        .unwrap();
    let beta_request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(64)
        .user("beta")
        .build()
        .unwrap();

    let alpha_preview = client
        .messages()
        .cost_preview_cached(&alpha_request, &pricing, &cache)
        .await
        .unwrap();
    let beta_preview = client
        .messages()
        .cost_preview_cached(&beta_request, &pricing, &cache)
        .await
        .unwrap();

    assert_eq!(alpha_preview.input_tokens, 100);
    assert_eq!(beta_preview.input_tokens, 200);
    assert_eq!(cache.len(), 2);
}
661
/// Converting a create request into a count-tokens request must keep model and messages
/// but serialize neither `max_tokens` nor `temperature`.
#[test]
fn count_tokens_request_from_create_drops_max_tokens_and_sampling() {
    let create_request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(64)
        .temperature(0.7)
        .user("hello")
        .build()
        .unwrap();

    let converted = CountTokensRequest::from(&create_request);
    assert_eq!(converted.model, create_request.model);
    assert_eq!(converted.messages.len(), 1);

    // Check the serialized form, since that is what goes over the wire.
    let serialized = serde_json::to_value(&converted).unwrap();
    assert!(serialized.get("max_tokens").is_none());
    assert!(serialized.get("temperature").is_none());
}
680
681 #[tokio::test]
682 async fn create_appends_assistant_prefill_in_history() {
683 let mock = MockServer::start().await;
685 Mock::given(method("POST"))
686 .and(path("/v1/messages"))
687 .and(body_partial_json(json!({
688 "messages": [
689 {"role": "user", "content": "hi"},
690 {"role": "assistant", "content": "Sure, "}
691 ]
692 })))
693 .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
694 .mount(&mock)
695 .await;
696
697 let client = client_for(&mock);
698 let req = CreateMessageRequest::builder()
699 .model(ModelId::SONNET_4_6)
700 .max_tokens(8)
701 .user("hi")
702 .assistant("Sure, ")
703 .build()
704 .unwrap();
705 let _ = client.messages().create(req).await.unwrap();
706 }
707
708 #[tokio::test]
709 async fn create_retries_on_overloaded_then_succeeds() {
710 let mock = MockServer::start().await;
711 Mock::given(method("POST"))
712 .and(path("/v1/messages"))
713 .respond_with(ResponseTemplate::new(529))
714 .up_to_n_times(1)
715 .mount(&mock)
716 .await;
717 Mock::given(method("POST"))
718 .and(path("/v1/messages"))
719 .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
720 .mount(&mock)
721 .await;
722
723 let client = Client::builder()
725 .api_key("sk-ant-x")
726 .base_url(mock.uri())
727 .retry(crate::retry::RetryPolicy {
728 max_attempts: 3,
729 initial_backoff: std::time::Duration::from_millis(1),
730 max_backoff: std::time::Duration::from_millis(5),
731 jitter: crate::retry::Jitter::None,
732 respect_retry_after: false,
733 })
734 .build()
735 .unwrap();
736
737 let req = CreateMessageRequest::builder()
738 .model(ModelId::SONNET_4_6)
739 .max_tokens(8)
740 .user("x")
741 .build()
742 .unwrap();
743 let resp = client.messages().create(req).await.unwrap();
744 assert_eq!(resp.id, "msg_test");
745 assert_eq!(mock.received_requests().await.unwrap().len(), 2);
746 }
747
/// Compile-time style check: a `Messages` handle can be taken from the client repeatedly.
#[test]
fn messages_namespace_borrows_client() {
    let client = Client::new("sk-ant-x");
    {
        // Handle is dropped at the end of this scope.
        let _scoped = client.messages();
    }
    // A second handle can still be obtained afterwards.
    let _ = client.messages();

    // Presumably here so the `MessageInput` import is exercised; confirm intent.
    let _: MessageInput = MessageInput::user("x");
}
761
/// Canned server-sent-event body used by the streaming tests.
///
/// Contains eight events in order: `message_start`, `content_block_start`, `ping`, two
/// `content_block_delta` events carrying "Hello" and " world", `content_block_stop`,
/// `message_delta` with `stop_reason` `end_turn`, and `message_stop`.
#[cfg(feature = "streaming")]
fn sse_corpus() -> &'static str {
    concat!(
        "event: message_start\n",
        "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_S\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"usage\":{\"input_tokens\":3,\"output_tokens\":0}}}\n",
        "\n",
        "event: content_block_start\n",
        "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
        "\n",
        "event: ping\n",
        "data: {\"type\":\"ping\"}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n",
        "\n",
        "event: content_block_delta\n",
        "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n",
        "\n",
        "event: content_block_stop\n",
        "data: {\"type\":\"content_block_stop\",\"index\":0}\n",
        "\n",
        "event: message_delta\n",
        "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}\n",
        "\n",
        "event: message_stop\n",
        "data: {\"type\":\"message_stop\"}\n",
        "\n",
    )
}
796
/// Aggregating the canned SSE stream must yield one message whose single text block is
/// the concatenation of both deltas.
#[cfg(feature = "streaming")]
#[tokio::test]
async fn create_stream_aggregates_to_full_message() {
    use crate::messages::content::{ContentBlock, KnownBlock};
    use crate::messages::stream::EventStream;

    let server = MockServer::start().await;
    let sse_reply = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_string(sse_corpus());
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .and(body_partial_json(json!({"stream": true})))
        .respond_with(sse_reply)
        .mount(&server)
        .await;

    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(8)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let events: EventStream = client.messages().create_stream(request).await.unwrap();
    let aggregated = events.aggregate().await.unwrap();

    assert_eq!(aggregated.id, "msg_S");
    assert_eq!(aggregated.stop_reason, Some(StopReason::EndTurn));
    assert_eq!(aggregated.usage.output_tokens, 2);
    assert_eq!(aggregated.content.len(), 1);
    if let ContentBlock::Known(KnownBlock::Text { text, .. }) = &aggregated.content[0] {
        assert_eq!(text, "Hello world");
    } else {
        panic!("expected text block");
    }
}
837
/// Polling the stream directly must yield the individual events, including `message_stop`.
#[cfg(feature = "streaming")]
#[tokio::test]
async fn create_stream_yields_individual_events_for_iterator_use() {
    use futures_util::StreamExt;

    let server = MockServer::start().await;
    let sse_reply = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_string(sse_corpus());
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .respond_with(sse_reply)
        .mount(&server)
        .await;

    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(8)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let mut events = client.messages().create_stream(request).await.unwrap();
    let mut saw_message_stop = false;
    let mut count = 0;
    while let Some(item) = events.next().await {
        let event = item.unwrap();
        count += 1;
        saw_message_stop |= event.type_tag() == Some("message_stop");
    }
    assert!(saw_message_stop, "expected to see message_stop event");
    assert!(count >= 7, "expected at least 7 events, got {count}");
}
875
/// A 401 on the streaming endpoint must surface as an error with status and request id.
#[cfg(feature = "streaming")]
#[tokio::test]
async fn create_stream_propagates_connect_error() {
    let server = MockServer::start().await;
    let unauthorized = ResponseTemplate::new(401)
        .insert_header("request-id", "req_unauth")
        .set_body_json(json!({
            "type": "error",
            "error": {"type": "authentication_error", "message": "bad key"}
        }));
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .respond_with(unauthorized)
        .mount(&server)
        .await;

    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(8)
        .user("hi")
        .build()
        .unwrap();
    let client = client_for(&server);

    let failure = client.messages().create_stream(request).await.unwrap_err();
    assert_eq!(failure.status(), Some(http::StatusCode::UNAUTHORIZED));
    assert_eq!(failure.request_id(), Some("req_unauth"));
}
905
/// `create_stream` must send `"stream": true` even though the builder never set it.
#[cfg(feature = "streaming")]
#[tokio::test]
async fn create_stream_sets_stream_true_in_request_body() {
    let server = MockServer::start().await;
    let sse_reply = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_string(sse_corpus());
    // The mock only matches bodies containing `"stream": true`; any other body finds no
    // matching mock, so the `unwrap` below would fail.
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .and(body_partial_json(json!({"stream": true})))
        .respond_with(sse_reply)
        .mount(&server)
        .await;

    let request = CreateMessageRequest::builder()
        .model(ModelId::SONNET_4_6)
        .max_tokens(8)
        .user("x")
        .build()
        .unwrap();
    let client = client_for(&server);
    let _ = client.messages().create_stream(request).await.unwrap();
}
932}